Skip to content

Commit 8a71747

Browse files
authored
Fix and add workaround for falcon-7b model test (#719)
* Add fallback to arange and argmax variations for falcon-7b * Append argmax blocklist instead since it already exists * Add arange.start_step to constantfolding pass instead * Fix arange test * Correct error message
1 parent 4d53823 commit 8a71747

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

tests/lowering/creation/test_arange.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,6 @@ def test_arange_start_step(device, input_shapes):
9090

9191
# Check the graph has be rewritten and contain ttnn ops
9292
nodes = list(option._out_fx_graphs[0].nodes)
93-
assert [node.target for node in nodes].count(ttnn.arange) == 1
93+
assert [node.target for node in nodes].count(ttnn.arange) == 1 or [node.op for node in nodes].count("get_attr")
9494
# Check inference result
9595
assert torch.allclose(result_before, result_after)

torch_ttnn/passes/constant_folding_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def __init__(self):
1111
torch.ops.aten.lift_fresh_copy.default,
1212
torch.ops.aten.pow.Tensor_Tensor,
1313
torch.ops.aten.arange.start,
14+
torch.ops.aten.arange.start_step,
1415
torch.ops.aten.unsqueeze.default,
1516
torch.ops.aten.arange.default,
1617
torch.ops.aten.view.default,

torch_ttnn/passes/lowering/to_tt_guard.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@
8383
["List[Tensor] tensors = [<[13600]>, <[13600]>, <[13600]>, <[13600]>]", "int dim = 1"],
8484
]
8585

86+
############################################################
87+
# EXTRA BLOCKLIST OF falcon-7b-instruct
88+
############################################################
89+
# RuntimeError: TT_THROW @ /tmp/build-via-sdist-d26xvola/ttnn-0.54.0rc18+wormhole.b0/ttnn/cpp/ttnn/device_operation.hpp:487: tt::exception
90+
# Unsupported storage type
91+
aten_argmax_default_blocklist += [["Tensor<[1, 7]> self = ?", "Optional[int] dim = 1", "bool keepdim = True"]]
92+
8693
############################################################
8794
# EXTRA BLOCKLIST OF retinanet_resnet50_fpn_v2
8895
############################################################

0 commit comments

Comments
 (0)