Skip to content

Commit 04ed89c

Browse files
authored
Add support for MLIR 18 (#79)
* Add missing ops of LLVM dialect * Register dialect files for LLVM 18 * Format code * Update LLVM version to v18.1.7+2 for binding generation * Fix ArmSVE file path in MLIR 18 * Fix tblgen to julia generation on multiple optional named operands * Minor fixes on `mlir_jl_tblgen` * Sanitize operand names on usage * Generate MLIR 18 dialect bindings * Include `v18` as a valid MLIR versioned module * Format code
1 parent c72da79 commit 04ed89c

File tree

184 files changed

+173543
-34337
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

184 files changed

+173543
-34337
lines changed

bindings/make.jl

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ function mlir_dialects(version::VersionNumber)
2929
"Linalg.jl",
3030
["Linalg/IR/LinalgOps.td", "Linalg/IR/LinalgStructuredOps.td"],
3131
),
32-
("llvm", "LLVMIR.jl", ["LLVMIR/LLVMOps.td"]),
32+
(
33+
"llvm",
34+
"LLVMIR.jl",
35+
["LLVMIR/LLVMOps.td", "LLVMIR/NVVMOps.td", "LLVMIR/ROCDLOps.td"],
36+
),
3337
("math", "Math.jl", ["Math/IR/MathOps.td"]),
3438
("memref", "MemRef.jl", ["MemRef/IR/MemRefOps.td"]),
3539
("omp", "OpenMP.jl", ["OpenMP/OpenMPOps.td"]),
@@ -69,7 +73,16 @@ function mlir_dialects(version::VersionNumber)
6973
"Linalg.jl",
7074
["Linalg/IR/LinalgOps.td", "Linalg/IR/LinalgStructuredOps.td"],
7175
),
72-
("llvm", "LLVMIR.jl", ["LLVMIR/LLVMOps.td"]),
76+
(
77+
"llvm",
78+
"LLVMIR.jl",
79+
[
80+
"LLVMIR/LLVMOps.td",
81+
"LLVMIR/LLVMIntrinsicOps.td",
82+
"LLVMIR/NVVMOps.td",
83+
"LLVMIR/ROCDLOps.td",
84+
],
85+
),
7386
("math", "Math.jl", ["Math/IR/MathOps.td"]),
7487
("memref", "MemRef.jl", ["MemRef/IR/MemRefOps.td"]),
7588
("ml_program", "MLProgram.jl", ["MLProgram/IR/MLProgramOps.td"]),
@@ -121,7 +134,16 @@ function mlir_dialects(version::VersionNumber)
121134
"Linalg.jl",
122135
["Linalg/IR/LinalgOps.td", "Linalg/IR/LinalgStructuredOps.td"],
123136
),
124-
("llvm", "LLVMIR.jl", ["LLVMIR/LLVMOps.td"]),
137+
(
138+
"llvm",
139+
"LLVMIR.jl",
140+
[
141+
"LLVMIR/LLVMOps.td",
142+
"LLVMIR/LLVMIntrinsicOps.td",
143+
"LLVMIR/NVVMOps.td",
144+
"LLVMIR/ROCDLOps.td",
145+
],
146+
),
125147
("math", "Math.jl", ["Math/IR/MathOps.td"]),
126148
("memref", "MemRef.jl", ["MemRef/IR/MemRefOps.td"]),
127149
("ml_program", "MLProgram.jl", ["MLProgram/IR/MLProgramOps.td"]),
@@ -179,7 +201,16 @@ function mlir_dialects(version::VersionNumber)
179201
"Linalg.jl",
180202
["Linalg/IR/LinalgOps.td", "Linalg/IR/LinalgStructuredOps.td"],
181203
),
182-
("llvm", "LLVMIR.jl", ["LLVMIR/LLVMOps.td"]),
204+
(
205+
"llvm",
206+
"LLVMIR.jl",
207+
[
208+
"LLVMIR/LLVMOps.td",
209+
"LLVMIR/LLVMIntrinsicOps.td",
210+
"LLVMIR/NVVMOps.td",
211+
"LLVMIR/ROCDLOps.td",
212+
],
213+
),
183214
("math", "Math.jl", ["Math/IR/MathOps.td"]),
184215
("memref", "MemRef.jl", ["MemRef/IR/MemRefOps.td"]),
185216
("ml_program", "MLProgram.jl", ["MLProgram/IR/MLProgramOps.td"]),
@@ -215,6 +246,88 @@ function mlir_dialects(version::VersionNumber)
215246
("vector", "Vector.jl", ["Vector/IR/VectorOps.td"]),
216247
("x86vector", "X86Vector.jl", ["X86Vector/X86Vector.td"]),
217248
]
249+
elseif v"18" <= version < v"19"
250+
[
251+
("acc", "OpenACC.jl", ["OpenACC/OpenACCOps.td"]),
252+
("affine", "Affine.jl", ["Affine/IR/AffineOps.td"]),
253+
("amdgpu", "AMDGPU.jl", ["AMDGPU/IR/AMDGPU.td"]),
254+
("amx", "AMX.jl", ["AMX/AMX.td"]),
255+
("arith", "Arith.jl", ["Arith/IR/ArithOps.td"]),
256+
("arm_neon", "ArmNeon.jl", ["ArmNeon/ArmNeon.td"]),
257+
(
258+
"arm_sme",
259+
"ArmSME.jl",
260+
["ArmSME/IR/ArmSMEOps.td", "ArmSME/IR/ArmSMEIntrinsicOps.td"],
261+
),
262+
("arm_sve", "ArmSVE.jl", ["ArmSVE/IR/ArmSVE.td"]),
263+
("async", "Async.jl", ["Async/IR/AsyncOps.td"]),
264+
("bufferization", "Bufferization.jl", ["Bufferization/IR/BufferizationOps.td"]),
265+
("builtin", "Builtin.jl", ["../IR/BuiltinOps.td"]),
266+
("cf", "ControlFlow.jl", ["ControlFlow/IR/ControlFlowOps.td"]),
267+
("complex", "Complex.jl", ["Complex/IR/ComplexOps.td"]),
268+
# ("dlti", "DLTI.jl", ["DLTI/DLTI.td"]), # TODO crashes
269+
("emitc", "EmitC.jl", ["EmitC/IR/EmitC.td"]),
270+
("func", "Func.jl", ["Func/IR/FuncOps.td"]),
271+
("gpu", "GPU.jl", ["GPU/IR/GPUOps.td"]),
272+
("index", "Index.jl", ["Index/IR/IndexOps.td"]),
273+
("irdl", "IRDL.jl", ["IRDL/IR/IRDLOps.td"]),
274+
(
275+
"linalg",
276+
"Linalg.jl",
277+
["Linalg/IR/LinalgOps.td", "Linalg/IR/LinalgStructuredOps.td"],
278+
),
279+
(
280+
"llvm",
281+
"LLVMIR.jl",
282+
[
283+
"LLVMIR/LLVMOps.td",
284+
"LLVMIR/LLVMIntrinsicOps.td",
285+
"LLVMIR/NVVMOps.td",
286+
"LLVMIR/ROCDLOps.td",
287+
],
288+
),
289+
("math", "Math.jl", ["Math/IR/MathOps.td"]),
290+
("memref", "MemRef.jl", ["MemRef/IR/MemRefOps.td"]),
291+
("mesh", "Mesh.jl", ["Mesh/IR/MeshOps.td"]),
292+
("ml_program", "MLProgram.jl", ["MLProgram/IR/MLProgramOps.td"]),
293+
("nvgpu", "NVGPU.jl", ["NVGPU/IR/NVGPU.td"]),
294+
("omp", "OpenMP.jl", ["OpenMP/OpenMPOps.td"]),
295+
("pdl_interp", "PDLInterp.jl", ["PDLInterp/IR/PDLInterpOps.td"]),
296+
("pdl", "PDL.jl", ["PDL/IR/PDLOps.td"]),
297+
("quant", "Quant.jl", ["Quant/QuantOps.td"]),
298+
("scf", "SCF.jl", ["SCF/IR/SCFOps.td"]),
299+
("shape", "Shape.jl", ["Shape/IR/ShapeOps.td"]),
300+
("sparse_tensor", "SparseTensor.jl", ["SparseTensor/IR/SparseTensorOps.td"]),
301+
("spirv", "SPIRV.jl", ["SPIRV/IR/SPIRVOps.td"]),
302+
("tensor", "Tensor.jl", ["Tensor/IR/TensorOps.td"]),
303+
("tosa", "Tosa.jl", ["Tosa/IR/TosaOps.td"]),
304+
(
305+
"transform",
306+
"Transform.jl",
307+
[
308+
"Affine/TransformOps/AffineTransformOps.td",
309+
"Bufferization/TransformOps/BufferizationTransformOps.td",
310+
"Func/TransformOps/FuncTransformOps.td",
311+
"GPU/TransformOps/GPUTransformOps.td",
312+
"Linalg/TransformOps/LinalgMatchOps.td",
313+
"Linalg/TransformOps/LinalgTransformOps.td",
314+
"MemRef/TransformOps/MemRefTransformOps.td",
315+
"NVGPU/TransformOps/NVGPUTransformOps.td",
316+
"SCF/TransformOps/SCFTransformOps.td",
317+
"SparseTensor/TransformOps/SparseTensorTransformOps.td",
318+
"Tensor/TransformOps/TensorTransformOps.td",
319+
"Transform/IR/TransformOps.td",
320+
"Transform/DebugExtension/DebugExtensionOps.td",
321+
"Transform/LoopExtension/LoopExtensionOps.td",
322+
"Transform/PDLExtension/PDLExtensionOps.td",
323+
"Vector/TransformOps/VectorTransformOps.td",
324+
],
325+
),
326+
("ub", "UB.jl", ["UB/IR/UBOps.td"]),
327+
("vector", "Vector.jl", ["Vector/IR/VectorOps.td"]),
328+
("x86vector", "X86Vector.jl", ["X86Vector/X86Vector.td"]),
329+
]
330+
218331
else
219332
error("Unsupported MLIR version: $version")
220333
end
@@ -224,12 +337,13 @@ end
224337

225338
function rewrite!(dag::ExprDAG) end
226339

227-
julia_llvm = Dict([
228-
v"1.9" => v"14.0.5+3",
229-
v"1.10" => v"15.0.7+10",
230-
v"1.11" => v"16.0.6+2",
231-
v"1.12" => v"17.0.6+3",
232-
])
340+
julia_llvm = [
341+
(v"1.9", v"14.0.5+3"),
342+
(v"1.10", v"15.0.7+10"),
343+
(v"1.11", v"16.0.6+2"),
344+
(v"1.12", v"17.0.6+3"),
345+
(v"1.12", v"18.1.7+2"),
346+
]
233347
options = load_options(joinpath(@__DIR__, "wrap.toml"))
234348

235349
@add_def off_t

deps/tblgen/jl-generators.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <regex>
1818
#include <optional>
1919
#include <iostream>
20+
#include <numeric>
2021

2122
#include "llvm/ADT/STLExtras.h"
2223
#include "llvm/ADT/Sequence.h"
@@ -307,7 +308,7 @@ end
307308
for (int i = 0; i < op.getNumOperands(); i++)
308309
{
309310
const auto &named_operand = op.getOperand(i);
310-
std::string operandname = named_operand.name.str();
311+
std::string operandname = sanitizeName(named_operand.name.str());
311312
if (operandname.empty())
312313
{
313314
operandname = "operand_" + std::to_string(i);
@@ -317,8 +318,8 @@ end
317318
else
318319
opseglist.push_back(named_operand.isVariadic() ? "length(" + operandname + "), " : "1, ");
319320
}
320-
std::string operandsegmentsizes = std::accumulate(std::begin(x), std::end(x), string(),
321-
[](string &ss, string &s)
321+
std::string operandsegmentsizes = std::accumulate(std::begin(opseglist), std::end(opseglist), std::string(),
322+
[](std::string &ss, std::string &s)
322323
{
323324
return ss.empty() ? s : ss + "," + s;
324325
});

0 commit comments

Comments
 (0)