Skip to content

Commit

Permalink
Updates conversion of torch.expand to ttnn.repeat to allow repeating …
Browse files Browse the repository at this point in the history
…on (#803)

last dimension.

This has been fixed in tt-metal, so we can now use tensor expand on the
last dimension. This also allows us to run mobilenet_ssd end to end.

closes 436
  • Loading branch information
jmalone-tt authored Mar 3, 2025
1 parent 8c78b36 commit 0c5d9c6
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/models/mobilenet_ssd/test_mobilenet_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _load_inputs(self):
"mode",
["eval"],
)
@pytest.mark.converted_end_to_end
def test_mobilenet_ssd(record_property, mode):
model_name = "MobileNetSSD"
record_property("model_name", model_name)
Expand Down
3 changes: 1 addition & 2 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,7 @@ def reshape_1d(code, args=args, kwargs=kwargs):
# aten.expand and ttnn.repeat has different meaning for their `shape` argument
# aten.expand: the desired output shape, where respective singleton dims are broadcasted
# ttnn.repeat: the number of times to repeat a respective singleton dim
# Repeat fails if last dimension of input is 1
if input_tensor_shape[-1] != 1 and len(input_tensor_shape) == len(output_shape):
if len(input_tensor_shape) == len(output_shape):
return g.call_function(target_wrappers.repeat, args=(args[0], multiplier.tolist()))

return None
Expand Down

0 comments on commit 0c5d9c6

Please sign in to comment.