
Incompatible cast error in generated MLIR for Flux model #2652

@prosenjitdhole

Description


While running the test sharktank/tests/models/flux/flux_test.py::FluxTest::testCompareToyIreeF32AgainstEagerF64, we generate invalid MLIR that fails to compile with the following error:

E               iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
E               Error code: 1
E               Diagnostics:
E               /tmp/tmpc28aze5uFluxTest/model.mlir:544:12: error: 'tensor.cast' op operand type 'tensor<1x17x4xf32>' and result type 'tensor<1x17x4xf64>' are cast incompatible
E                   %202 = torch.aten.einsum %str, %201, %none_102 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[1,17,4],f64>
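For context on the diagnostic: in MLIR, tensor.cast may only change static shape information (for example, turning a ? dimension into 17). It cannot change the element type, so going from f32 to f64 needs an element-wise float extension such as arith.extf instead. The Python sketch below is a simplified model of that compatibility rule, assuming ranked tensor types without encodings; the function names are hypothetical and not part of MLIR or IREE.

```python
import re

# Simplified model of the tensor.cast compatibility rule (hypothetical helper).
# Assumptions: ranked tensor types only, no encodings, no unranked tensor<*x...>.
_TENSOR_RE = re.compile(r"^tensor<((?:(?:\d+|\?)x)*)([A-Za-z0-9_]+)>$")


def parse_tensor_type(text):
    """Split 'tensor<1x17x4xf32>' into (['1', '17', '4'], 'f32')."""
    m = _TENSOR_RE.match(text.strip())
    if m is None:
        raise ValueError(f"unsupported tensor type: {text!r}")
    dims = [d for d in m.group(1).split("x") if d]
    return dims, m.group(2)


def are_cast_compatible(src, dst):
    """True if tensor.cast could convert src to dst under the rule above."""
    src_dims, src_elem = parse_tensor_type(src)
    dst_dims, dst_elem = parse_tensor_type(dst)
    if src_elem != dst_elem:
        # Element types must match exactly; f32 -> f64 is not a cast.
        return False
    if len(src_dims) != len(dst_dims):
        return False
    # Each dimension pair must be equal, or one side must be dynamic ("?").
    return all(a == b or "?" in (a, b) for a, b in zip(src_dims, dst_dims))


print(are_cast_compatible("tensor<1x17x4xf32>", "tensor<1x17x4xf64>"))  # False
print(are_cast_compatible("tensor<1x?x4xf32>", "tensor<1x17x4xf32>"))   # True
```

The first call mirrors the failing cast at model.mlir:544: shapes agree, but the element types differ, which is why the verifier rejects it.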

The generated MLIR is attached: test.mlir.txt
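To find where f64 enters an otherwise f32 module, a quick text scan of the attached file may help. This is a hypothetical helper, assuming the torch-dialect syntax shown in the diagnostic; it only inspects single result types that follow "->", so operand types and multi-result signatures are not covered.

```python
import re
from collections import Counter

# Matches a single torch result type after "->", e.g.
# "-> !torch.vtensor<[1,17,4],f64>" captures "f64".
_RESULT_RE = re.compile(r"->\s*!torch\.vtensor<\[[^\]]*\],(\w+)>")


def result_dtypes(mlir_text):
    """Return (Counter of result element types, [(line_no, line)] for f64)."""
    counts = Counter()
    f64_lines = []
    for line_no, line in enumerate(mlir_text.splitlines(), start=1):
        for dtype in _RESULT_RE.findall(line):
            counts[dtype] += 1
            if dtype == "f64":
                f64_lines.append((line_no, line.strip()))
    return counts, f64_lines


# Demo on the op quoted in the diagnostic above.
line_544 = ("%202 = torch.aten.einsum %str, %201, %none_102 : !torch.str, "
            "!torch.list<vtensor>, !torch.none -> !torch.vtensor<[1,17,4],f64>")
counts, f64_lines = result_dtypes(line_544)
print(dict(counts))  # {'f64': 1}
```

On the attachment, result_dtypes(open("test.mlir.txt").read()) would list every op whose result is annotated f64; the earliest such line is a reasonable place to start looking for the dtype leak.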

Command to compile the attached MLIR (save test.mlir.txt as test.mlir first):

iree-compile test.mlir \
  --iree-input-type=auto \
  --iree-vm-bytecode-module-output-format=flatbuffer-binary \
  -o=model.vmfb \
  --mlir-print-debuginfo \
  --mlir-print-op-on-diagnostic=false \
  --iree-execution-model=async-external \
  --iree-global-opt-propagate-transposes=1 \
  --iree-opt-const-eval=0 \
  --iree-opt-outer-dim-concat=1 \
  --iree-opt-aggressively-propagate-transposes=1 \
  --iree-codegen-llvmgpu-use-vector-distribution=1 \
  --iree-llvmgpu-enable-prefetch=1 \
  --iree-opt-data-tiling=0 \
  --iree-vm-target-truncate-unsupported-floats \
  --iree-dispatch-creation-enable-aggressive-fusion \
  --iree-hal-memoization=1 \
  --iree-codegen-llvmgpu-use-tile-and-fuse-matmul=1 \
  '--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize),iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-dispatch-creation-bubble-up-expand-shapes, canonicalize, cse, canonicalize), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-target-device=local \
  --iree-hal-local-target-device-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=host
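When scripting the repro from Python, the same command can be passed as an argv list. This is a sketch with hypothetical helper names; it uses only subprocess and shutil from the standard library and assumes iree-compile is on PATH. Because no shell is involved, the single quotes around the pass-pipeline flag are dropped: they are shell quoting, not part of the flag value.

```python
import shutil
import subprocess

# Pass pipeline from the command above, split only for readability.
PIPELINE = (
    "builtin.module("
    "util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize),"
    "iree-preprocessing-transpose-convolution-pipeline, "
    "iree-preprocessing-pad-to-intrinsics, "
    "util.func(iree-dispatch-creation-bubble-up-expand-shapes, canonicalize, cse, canonicalize), "
    "util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
)

# Flags copied from the shell command (same order, -o handled separately).
IREE_FLAGS = [
    "--iree-input-type=auto",
    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    "--mlir-print-debuginfo",
    "--mlir-print-op-on-diagnostic=false",
    "--iree-execution-model=async-external",
    "--iree-global-opt-propagate-transposes=1",
    "--iree-opt-const-eval=0",
    "--iree-opt-outer-dim-concat=1",
    "--iree-opt-aggressively-propagate-transposes=1",
    "--iree-codegen-llvmgpu-use-vector-distribution=1",
    "--iree-llvmgpu-enable-prefetch=1",
    "--iree-opt-data-tiling=0",
    "--iree-vm-target-truncate-unsupported-floats",
    "--iree-dispatch-creation-enable-aggressive-fusion",
    "--iree-hal-memoization=1",
    "--iree-codegen-llvmgpu-use-tile-and-fuse-matmul=1",
    f"--iree-preprocessing-pass-pipeline={PIPELINE}",
    "--iree-hal-target-device=local",
    "--iree-hal-local-target-device-backends=llvm-cpu",
    "--iree-llvmcpu-target-cpu=host",
]


def build_compile_argv(input_path="test.mlir", output_path="model.vmfb"):
    """Build the iree-compile argv list (no shell quoting required)."""
    return ["iree-compile", input_path, *IREE_FLAGS, f"-o={output_path}"]


def compile_mlir(input_path="test.mlir", output_path="model.vmfb"):
    """Run iree-compile if it is on PATH; return (returncode, stderr)."""
    exe = shutil.which("iree-compile")
    if exe is None:
        raise FileNotFoundError("iree-compile not found on PATH")
    argv = build_compile_argv(input_path, output_path)
    argv[0] = exe
    proc = subprocess.run(argv, capture_output=True, text=True)
    return proc.returncode, proc.stderr
```

With the attached file, compile_mlir() should return a nonzero code and stderr containing the same tensor.cast diagnostic.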

Steps to reproduce the issue:

pytest -n 20 sharktank/tests/models/flux/flux_test.py::FluxTest::testCompareToyIreeF32AgainstEagerF64 --cov=sharktank --cov-report xml:cov.xml --cov-config=.coveragerc --durations=10 --log-cli-level=info -v

Notes:

  1. This happens on the integrate/iree branch.
  2. IREE is pinned to:

IREE (https://iree.dev):
IREE compiler version 3.9.0rc20251105 @ c65dc6dc28491b4768a72ff9d563edeb377627a9
LLVM version 22.0.0git
Optimized build
