Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

failed to legalize operation 'hal.interface.constant.load' #18487

Open
pdhirajkumarprasad opened this issue Sep 11, 2024 · 5 comments
Open

failed to legalize operation 'hal.interface.constant.load' #18487

pdhirajkumarprasad opened this issue Sep 11, 2024 · 5 comments
Assignees
Labels
bug 🐞 Something isn't working

Comments

@pdhirajkumarprasad
Copy link

What happened?

For the given IR:

module {
  func.func @main_graph(%arg0: !torch.vtensor<[1,128],si64>, %arg1: !torch.vtensor<[1,128],f32>, %arg2: !torch.vtensor<[?,?,?,?],f32>, %arg3: !torch.vtensor<[127,127],f32>, %arg4:!torch.vtensor<[2,2],si64>, %arg5: !torch.vtensor<[1,64,12,64],f32>, %arg6: !torch.vtensor<[1,128,12,64],f32>, %arg7: !torch.vtensor<[1],si64>, %arg8: !torch.vtensor<[2],si64> ) -> !torch.vtensor<[?,?,?,?],f32>  attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    %none = torch.constant.none
    %416 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__72> : tensor<7x7xf32>} : () -> !torch.vtensor<[7,7],f32> 
    %418 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__74> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %419 = torch.operator "onnx.ConstantOfShape"(%arg7) {torch.onnx.value = dense_resource<__75> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[0],si64> 
    %420 = torch.operator "onnx.Concat"(%418, %419) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[4],si64> 
    %422 = torch.operator "onnx.Reshape"(%420, %arg8) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,2],si64> 
    %423 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__77> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %427 = torch.operator "onnx.Slice"(%422, %423, %423, %423, %423) : (!torch.vtensor<[2,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,2],si64> 
    %428 = torch.operator "onnx.Transpose"(%427) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[2,2],si64>) -> !torch.vtensor<[2,2],si64> 
    %429 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__81> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %430 = torch.operator "onnx.Reshape"(%428, %429) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %431 = torch.operator "onnx.Cast"(%430) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> 
    %432 = torch.operator "onnx.Pad"(%416, %431, %none) {torch.onnx.mode = "constant"} : (!torch.vtensor<[7,7],f32>, !torch.vtensor<[4],si64>, !torch.none) -> !torch.vtensor<[?,?],f32> 
    %957 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__273> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %959 = torch.operator "onnx.Slice"(%432, %957, %957, %957, %957) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?],f32> 
    %960 = torch.operator "onnx.Concat"(%959, %432) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> 
    %963 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__277> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %965 = torch.operator "onnx.Slice"(%960, %963, %963, %963, %963) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?],f32> 
    %995 = torch.operator "onnx.Einsum"(%arg5, %arg6) {torch.onnx.equation = "bind,bjnd->bnij"} : (!torch.vtensor<[1,64,12,64],f32>, !torch.vtensor<[1,128,12,64],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %1043 = torch.operator "onnx.Mul"(%arg2, %965) : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %1076 = torch.operator "onnx.Add"(%995, %1043) : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    return %1076: !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      __72: "0x080000000000803F0000803F0000803F0000803F000080",
      __74: "0x080000000100000000000000000000000000000001000000000000000000000000000000",
      __75: "0x080000000000000000000000",
      __77: "0x080000000000000000000000",
      __81: "0x08000000FFFFFFFFFFFFFFFF",
      __273: "0x080000000100000000000000",
      __277: "0x08000000FFFFFFFFFFFFFFFF"
    }
  }
#-}

I am getting the following error:

model.mlir:22:13: error: failed to legalize operation 'hal.interface.constant.load'
    %1043 = torch.operator "onnx.Mul"(%arg2, %965) : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
            ^
model.mlir:22:13: note: see current operation: %84 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32

Dump produced with the flags '--mlir-print-ir-after-all --mlir-print-ir-before-all --mlir-disable-threading --mlir-elide-elementsattrs-if-larger=4' (due to the file size restriction, the initial part has been removed from the log):
dump.log

Steps to reproduce your issue

command:

iree-compile --iree-hal-target-backends=llvm-cpu model.mlir

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

@nirvedhmeshram
Copy link
Contributor

@pdhirajkumarprasad I am not able to reproduce the error mentioned in the issue. For me, this is failing in convert-torch-onnx-to-torch, the code for which lives in torch-mlir. Here is the crash dump.

@pdhirajkumarprasad
Copy link
Author

pdhirajkumarprasad commented Sep 27, 2024

@nirvedhmeshram we are seeing multiple crashes. The original issue is masked by nod-ai/SHARK-ModelDev#852.

@vinayakdsci
Copy link
Contributor

@pdhirajkumarprasad I am not able to reproduce the error mentioned in the issue. For me, this is failing in convert-torch-onnx-to-torch, the code for which lives in torch-mlir. Here is the crash dump.

I have the IR failing with the same crash, @nirvedhmeshram. This is happening because the constant __72 does not have the right buffer size to match 7x7 elements of bitwidth 32.

@pdhirajkumarprasad, is there a specific model that produces this IR? We could be looking at a possible import issue.

@vinayakdsci
Copy link
Contributor

@pdhirajkumarprasad I am unable to reproduce the issue with any of the models mentioned in nod-ai/SHARK-ModelDev#812 on the respective tracker, using the latest build of IREE.

@pdhirajkumarprasad
Copy link
Author

Currently I am seeing a crash for the following simple IR:

module {
  func.func @main_graph( %arg0: !torch.vtensor<[1],si64>, %arg1: !torch.vtensor<[2],si64> ) -> !torch.vtensor<[2,2],si64>   attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__74> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %2 = torch.operator "onnx.ConstantOfShape"(%arg0) {torch.onnx.value = dense_resource<__75> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[0],si64> 
    %3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[4],si64> 
    %4 = torch.operator "onnx.Reshape"(%3, %arg1) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,2],si64> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__77> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %6 = torch.operator "onnx.Slice"(%4, %5, %5, %5, %5) : (!torch.vtensor<[2,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,2],si64> 
    return %4: !torch.vtensor<[2,2],si64>
  }
}

{-#
  dialect_resources: {
    builtin: {
      __74: "0x080000000100000000000000000000000000000001000000000000000000000000000000",
      __75: "0x080000000000000000000000",
      __77: "0x080000000000000000000000"
    }
  }
#-}

Crash log:

Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.	Program arguments: my_env_hf/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile --iree-hal-target-backends=llvm-cpu t.mlir
 #0 0x00007f0086687ac8 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x5680ac8)
 #1 0x00007f008668586e llvm::sys::RunSignalHandlers() (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x567e86e)
 #2 0x00007f0086688176 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
 #3 0x00007f0080e21520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f008787bd5b auto mlir::torch::Torch::AtenSliceTensorOp::fold(mlir::torch::Torch::AtenSliceTensorOpGenericAdaptor<llvm::ArrayRef<mlir::Attribute> >)::$_0::operator()<$_0>($_0&, long, long) const TorchOps.cpp:0:0
 #5 0x00007f008787bab1 mlir::torch::Torch::AtenSliceTensorOp::fold(mlir::torch::Torch::AtenSliceTensorOpGenericAdaptor<llvm::ArrayRef<mlir::Attribute> >) TorchOps.cpp:0:0
 #6 0x00007f00877cccd0 llvm::LogicalResult mlir::Op<mlir::torch::Torch::AtenSliceTensorOp, mlir::OpTrait::ZeroRegions, mlir::OpTrait::OneResult, mlir::OpTrait::OneTypedResult<mlir::Type>::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::NOperands<5u>::Impl, mlir::OpTrait::OpInvariants, mlir::torch::Torch::OpTrait::AllowsTypeRefinement, mlir::torch::Torch::OpTrait::ReadOnly>::foldSingleResultHook<mlir::torch::Torch::AtenSliceTensorOp>(mlir::Operation*, llvm::ArrayRef<mlir::Attribute>, llvm::SmallVectorImpl<mlir::OpFoldResult>&) TorchDialect.cpp:0:0
 #7 0x00007f00877cc655 mlir::RegisteredOperationName::Model<mlir::torch::Torch::AtenSliceTensorOp>::foldHook(mlir::Operation*, llvm::ArrayRef<mlir::Attribute>, llvm::SmallVectorImpl<mlir::OpFoldResult>&) TorchDialect.cpp:0:0
 #8 0x00007f00867671cf mlir::Operation::fold(llvm::ArrayRef<mlir::Attribute>, llvm::SmallVectorImpl<mlir::OpFoldResult>&) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x57601cf)
 #9 0x00007f008676745b mlir::Operation::fold(llvm::SmallVectorImpl<mlir::OpFoldResult>&) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x576045b)
#10 0x00007f008b093fcf (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist() GreedyPatternRewriteDriver.cpp:0:0
#11 0x00007f008b09234b mlir::applyPatternsGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0xa08b34b)
#12 0x00007f008742aa7d (anonymous namespace)::DecomposeComplexOpsPass::runOnOperation() DecomposeComplexOps.cpp:0:0
#13 0x00007f00868cc2bb mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x58c52bb)
#14 0x00007f00868ccea9 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x58c5ea9)
#15 0x00007f00868cec38 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x58c7c38)
#16 0x00007f00868cc775 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x58c5775)
#17 0x00007f00868ccea9 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x58c5ea9)
#18 0x00007f00868d1fca llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (mlir::OpPassManager&, mlir::Operation*)>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_0>(long, mlir::OpPassManager&, mlir::Operation*) Pass.cpp:0:0
#19 0x00007f0087d3809c mlir::iree_compiler::InputConversion::(anonymous namespace)::AutoInputConversionPipelinePass::runOnOperation() AutoInputConversionPipeline.cpp:0:0
#20 0x00007f00868cc2bb mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x58c52bb)
#21 0x00007f00868d01cb mlir::PassManager::run(mlir::Operation*) (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x58c91cb)
#22 0x00007f00865da942 ireeCompilerInvocationPipeline (my_env_hf/lib/python3.10/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so+0x55d3942)
#23 0x00007f00867fd29d mlir::iree_compiler::runIreecMain(int, char**)::$_2::operator()(iree_compiler_source_t*) const iree_compile_lib.cc:0:0
#24 0x00007f00867fcb00 mlir::iree_compiler::runIreecMain(int, char**) iree_compile_lib.cc:0:0
#25 0x00007f0080e08d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
#26 0x00007f0080e08e40 call_init ./csu/../csu/libc-start.c:128:20
#27 0x00007f0080e08e40 __libc_start_main ./csu/../csu/libc-start.c:379:5
#28 0x00000000002016ae _start (my_env_hf/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile+0x2016ae)

Command: iree-compile t.mlir --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host

IREE version: IREE compiler version 3.3.0rc20250303 @ cc37664

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug 🐞 Something isn't working
Projects
Status: No status
Development

No branches or pull requests

3 participants