
Commit 6ea4d11

Merge pull request #2 from llvm/main
Get latest changes
2 parents: d40d74f + a265d28. Commit: 6ea4d11

46 files changed: +2090 −1409 lines changed


.github/workflows/pre-commit-all.yml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ jobs:
     runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         with:
           extra_args: --color=always --all-files

.github/workflows/pre-commit.yml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ jobs:
         with:
           # requites to grab the history of the PR
           fetch-depth: 0
-      - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         with:
           extra_args: --color=always --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
@@ -35,13 +35,17 @@ option(TORCH_MLIR_ENABLE_WERROR_FLAG "Enable `-Werror` flag on supported directo
 option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON)
 
 option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON)
+
 if(TORCH_MLIR_ENABLE_REFBACKEND)
   add_definitions(-DTORCH_MLIR_ENABLE_REFBACKEND)
 endif()
 
+set(TORCH_MLIR_TABLEGEN_FLAGS "")
+
 option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
 if(TORCH_MLIR_ENABLE_STABLEHLO)
   add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
+  list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_STABLEHLO")
 endif()
 # It is possible that both stablehlo and torch_mlir projects are used in some compiler project.
 # In this case, we don't want to use stablehlo that is downloaded by torch_mlir (in external/stablehlo)
@@ -50,6 +54,12 @@ endif()
 # stablehlo targets AND includes available (for example with `add_subdirectory` and `include_directories`).
 option(TORCH_MLIR_USE_EXTERNAL_STABLEHLO "Use stablehlo from top level project" OFF)
 
+option(TORCH_MLIR_ENABLE_TOSA "Add TOSA support" ON)
+if(TORCH_MLIR_ENABLE_TOSA)
+  add_definitions(-DTORCH_MLIR_ENABLE_TOSA)
+  list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_TOSA")
+endif()
+
 option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF)
 
 # PyTorch native extension gate. If OFF, then no features which depend on

build_tools/autogen_ltc_backend.py

Lines changed: 8 additions & 2 deletions
@@ -30,6 +30,12 @@
 TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
 TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
 
+# Safely load fast C Yaml loader if it is are available
+try:
+    from yaml import CSafeLoader as Loader
+except ImportError:
+    from yaml import SafeLoader as Loader  # type:ignore[assignment, misc]
+
 
 def reindent(text, prefix=""):
     return indent(dedent(text), prefix)
@@ -175,7 +181,7 @@ def generate_native_functions(self):
         )
         ts_native_yaml = None
         if ts_native_yaml_path.exists():
-            ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader)
+            ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), Loader)
         else:
             logging.warning(
                 f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}"
@@ -208,7 +214,7 @@ def get_opnames(ops):
         )
 
         with self.config_path.open() as f:
-            config = yaml.load(f, yaml.CLoader)
+            config = yaml.load(f, Loader)
 
         # List of unsupported ops in LTC autogen because of some error
         blacklist = set(config.get("blacklist", []))

externals/llvm-project

Submodule llvm-project updated 5542 files

include/torch-mlir-c/TorchTypes.h

Lines changed: 6 additions & 6 deletions
@@ -260,7 +260,7 @@ torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
 MLIR_CAPI_EXPORTED MlirType
 torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr);
 
-/// Gets the the rank (number of dimensions) of a !torch.tensor
+/// Gets the rank (number of dimensions) of a !torch.tensor
 MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t);
 
 /// Return true if this type has a list of sizes.
@@ -269,12 +269,12 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t);
 /// Return true if this type has a dtype.
 MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t);
 
-/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size
+/// Gets the sizes of the dimensions of a !torch.tensor; note -1 size
 /// indicates an unrefined/unknown size dimension.
 MLIR_CAPI_EXPORTED int64_t
 torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
 
-/// Gets the the dtype (data type) of a !torch.tensor.
+/// Gets the dtype (data type) of a !torch.tensor.
 MLIR_CAPI_EXPORTED MlirType
 torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
 
@@ -307,7 +307,7 @@ torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
 MLIR_CAPI_EXPORTED MlirType
 torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
 
-/// Gets the the rank (number of dimensions) of a !torch.vtensor
+/// Gets the rank (number of dimensions) of a !torch.vtensor
 MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t);
 
 /// Return true if this type has a list of sizes.
@@ -316,12 +316,12 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasSizes(MlirType t);
 /// Return true if this type has a dtype.
 MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t);
 
-/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size
+/// Gets the sizes of the dimensions of a !torch.vtensor; note -1 size
 /// indicates an unrefined/unknown size dimension.
 MLIR_CAPI_EXPORTED int64_t
 torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
 
-/// Gets the the dtype (data type) of a !torch.vtensor.
+/// Gets the dtype (data type) of a !torch.vtensor.
 MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
 
 /// Gets the !torch.vtensor typeid.
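
The query functions whose comments are cleaned up here form a small C API for inspecting !torch.tensor / !torch.vtensor types. The sketch below is hypothetical caller code (not part of this commit) showing the intended call sequence, assuming an MlirType `t` already known to hold a non-value tensor type:

// Hypothetical usage sketch of the torch-mlir C API declared above.
#include "torch-mlir-c/TorchTypes.h"
#include <vector>

void inspectTensorType(MlirType t) {
  if (torchMlirTorchNonValueTensorTypeHasSizes(t)) {
    int64_t rank = torchMlirTorchNonValueTensorTypeGetRank(t);
    std::vector<int64_t> sizes(rank);
    // A -1 entry means an unrefined/unknown dimension size.
    torchMlirTorchNonValueTensorTypeGetSizes(t, sizes.data());
  }
  if (torchMlirTorchNonValueTensorTypeHasDtype(t)) {
    MlirType dtype = torchMlirTorchNonValueTensorTypeGetDtype(t);
    (void)dtype; // e.g. hand off to further C-API queries
  }
}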
Lines changed: 5 additions & 5 deletions
@@ -1,11 +1,11 @@
 add_subdirectory(TorchOnnxToTorch)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-if(TORCH_MLIR_ENABLE_STABLEHLO)
-  mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
-else()
-  mlir_tablegen(Passes.h.inc -gen-pass-decls)
-endif()
+
+
+
+mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS})
+
 add_public_tablegen_target(TorchMLIRConversionPassIncGen)
 
 add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc)

include/torch-mlir/Conversion/Passes.td

Lines changed: 2 additions & 0 deletions
@@ -114,6 +114,7 @@ def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToTensorPass()";
 }
 
+#ifdef TORCH_MLIR_ENABLE_TOSA
 def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
   let summary = "Convert Torch ops to TOSA ops";
   let description = [{
@@ -122,6 +123,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
   }];
   let constructor = "mlir::torch::createConvertTorchToTosaPass()";
 }
+#endif
 
 def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h

Lines changed: 3 additions & 2 deletions
@@ -24,6 +24,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
                       SmallVector<int64_t> indiceOneDimShape, int32_t dim,
                       ArrayRef<int64_t> indexShape);
 
+// Default function to create TOSA op with shift value
 mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
                                      TensorType outType, Value lhs, Value rhs,
                                      int32_t shift);
@@ -32,8 +33,8 @@ mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
 template <typename TosaOpT>
 TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
                               TensorType outType, Value lhs, Value rhs) {
-  lhs = promoteType(rewriter, lhs, outType);
-  rhs = promoteType(rewriter, rhs, outType);
+  lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
+  rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
   return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
 }
 
include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 24 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
1111
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
1212

13-
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
14-
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
15-
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
16-
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
17-
#include "mlir/IR/PatternMatch.h" // from @llvm-project
18-
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
19-
#include "mlir/Support/LLVM.h" // from @llvm-project
13+
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
14+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
15+
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
16+
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
17+
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
18+
#include "mlir/IR/PatternMatch.h" // from @llvm-project
19+
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
20+
#include "mlir/Support/LLVM.h" // from @llvm-project
2021

2122
namespace mlir {
2223
namespace tosa {
@@ -45,6 +46,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
4546
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
4647
float val);
4748

49+
// Create an int8_t const tosa.mul shift tensor from an int
50+
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
51+
int32_t shift);
52+
4853
// Create a zero constant tensor of the desired type and shape.
4954
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
5055
Operation *op, Type type);
@@ -58,55 +63,24 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
5863
ArrayRef<T> vec, ArrayRef<int64_t> shape,
5964
std::optional<Type> dtype = {});
6065

61-
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
62-
Value src, Type destType, Value &result);
63-
64-
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
66+
// Default function to create tosa.cast op. This should be called instead of
67+
// directly calling rewriter.create<tosa::CastOp>.
68+
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
69+
TensorType destType);
6570

6671
// Creates a TOSA operation and performs shape inference on the individual
6772
// op. This allows shape inference during the framework to TOSA lowering.
73+
template <typename TosaOp, typename... Args>
74+
TosaOp CreateOpAndInfer(ImplicitLocOpBuilder &builder, Type result_ty,
75+
Args &&...args) {
76+
return CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
77+
}
78+
6879
template <typename TosaOp, typename... Args>
6980
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
7081
Args &&...args) {
71-
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);
72-
73-
InferShapedTypeOpInterface shapeInterface =
74-
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
75-
if (!shapeInterface)
76-
return op;
77-
78-
SmallVector<ShapedTypeComponents> returnedShapes;
79-
if (shapeInterface
80-
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
81-
op->getOperands(), op->getAttrDictionary(),
82-
op->getPropertiesStorage(),
83-
op->getRegions(), returnedShapes)
84-
.failed())
85-
return op;
86-
87-
// We need to use the element type of the existing result type to generate
88-
// the new result shaped type. This is because rescale can include a cast to
89-
// different bit-width types and does not have a TypeAttr to define the
90-
// target type.
91-
auto result = op->getResult(0);
92-
auto predictedShape = returnedShapes[0];
93-
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
94-
95-
// Compute the knowledge based on the inferred type.
96-
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
97-
inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
98-
inferredKnowledge.hasRank = predictedShape.hasRank();
99-
if (predictedShape.hasRank()) {
100-
for (auto dim : predictedShape.getDims()) {
101-
inferredKnowledge.sizes.push_back(dim);
102-
}
103-
}
104-
105-
// Compute the new type based on the joined version.
106-
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
107-
auto new_ty = newKnowledge.getType();
108-
result.setType(new_ty);
109-
return op;
82+
ImplicitLocOpBuilder builder(loc, rewriter);
83+
return CreateOpAndInfer<TosaOp>(builder, result_ty, args...);
11084
}
11185

11286
template <typename TosaOp, typename... Args>
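
The casting helper now returns std::optional<Value> instead of filling an out-parameter, so call sites either unwrap it (as createBinaryOpAndCast does above with .value()) or treat std::nullopt as a match failure. A minimal sketch of the checked form inside a hypothetical rewrite-pattern helper; only tosaCastTensorToType itself is taken from the header above:

// Hypothetical helper around the new casting API.
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

using namespace mlir;

static LogicalResult castToResultType(PatternRewriter &rewriter, Operation *op,
                                      Value src, TensorType destType,
                                      Value &out) {
  std::optional<Value> casted = tosa::tosaCastTensorToType(rewriter, src, destType);
  if (!casted)
    return rewriter.notifyMatchFailure(op, "unsupported cast to TOSA result type");
  out = *casted;
  return success();
}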

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 32 additions & 4 deletions
@@ -5236,6 +5236,7 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
@@ -5260,6 +5261,7 @@ def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [
@@ -13704,7 +13706,7 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)`";
+  let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?, bool?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
     Torch_IntType:$n_fft,
@@ -13713,18 +13715,19 @@
     AnyTorchOptionalTensorType:$window,
     Torch_BoolType:$normalized,
     AnyTorchOptionalBoolType:$onesided,
-    AnyTorchOptionalBoolType:$return_complex
+    AnyTorchOptionalBoolType:$return_complex,
+    AnyTorchOptionalBoolType:$align_to_window
   );
   let results = (outs
     AnyTorchOptionalTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
     ParseResult AtenStftOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 8, 1);
+      return parseDefaultTorchOp(parser, result, 9, 1);
     }
     void AtenStftOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 8, 1);
+      printDefaultTorchOp(printer, *this, 9, 1);
     }
   }];
 }
@@ -15175,6 +15178,31 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [
   let hasFolder = 1;
 }
 
+def Torch_AtenArgsortOp : Torch_Op<"aten.argsort", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::argsort : (Tensor, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    Torch_BoolType:$descending
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenArgsortOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenArgsortOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [
     AllowsTypeRefinement,
     ReadOnly
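
Setting `let hasCanonicalizer = 1;` on the two pow ops makes ODS declare a getCanonicalizationPatterns hook on the generated C++ op classes; the patterns themselves are defined elsewhere in the dialect's hand-written C++ sources and are not shown in this diff. A sketch of what a matching stub definition looks like (the registered pattern name is hypothetical):

// Stub for the hook that `hasCanonicalizer = 1` causes ODS to declare.
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch::Torch;

void AtenPowTensorScalarOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // e.g. patterns.add<FoldPowWithUnitExponent>(context); // hypothetical pattern
}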
Lines changed: 3 additions & 5 deletions
@@ -1,9 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-if(TORCH_MLIR_ENABLE_STABLEHLO)
-  mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
-else()
-  mlir_tablegen(Passes.h.inc -gen-pass-decls)
-endif()
+
+mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS})
+
 add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen)
 
 add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc)
