Skip to content

Commit

Permalink
Fix and add workaround for falcon-7b model test (#719)
Browse files Browse the repository at this point in the history
* Add fallback to arange and argmax variations for falcon-7b

* Append argmax blocklist instead since it already exists

* Add arange.start_step to constantfolding pass instead

* Fix arange test

* Correct error message
  • Loading branch information
kevinwuTT authored Jan 22, 2025
1 parent 4d53823 commit 8a71747
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/lowering/creation/test_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,6 @@ def test_arange_start_step(device, input_shapes):

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.arange) == 1
assert [node.target for node in nodes].count(ttnn.arange) == 1 or [node.op for node in nodes].count("get_attr")
# Check inference result
assert torch.allclose(result_before, result_after)
1 change: 1 addition & 0 deletions torch_ttnn/passes/constant_folding_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self):
torch.ops.aten.lift_fresh_copy.default,
torch.ops.aten.pow.Tensor_Tensor,
torch.ops.aten.arange.start,
torch.ops.aten.arange.start_step,
torch.ops.aten.unsqueeze.default,
torch.ops.aten.arange.default,
torch.ops.aten.view.default,
Expand Down
7 changes: 7 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@
["List[Tensor] tensors = [<[13600]>, <[13600]>, <[13600]>, <[13600]>]", "int dim = 1"],
]

############################################################
# EXTRA BLOCKLIST OF falcon-7b-instruct
############################################################
# RuntimeError: TT_THROW @ /tmp/build-via-sdist-d26xvola/ttnn-0.54.0rc18+wormhole.b0/ttnn/cpp/ttnn/device_operation.hpp:487: tt::exception
# Unsupported storage type
aten_argmax_default_blocklist += [["Tensor<[1, 7]> self = ?", "Optional[int] dim = 1", "bool keepdim = True"]]

############################################################
# EXTRA BLOCKLIST OF retinanet_resnet50_fpn_v2
############################################################
Expand Down

0 comments on commit 8a71747

Please sign in to comment.