Skip to content

Commit 4a27c75

Browse files
committed
Revert "Add ttnn.squeeze for lowering embedding when the rank of input is 1"
This reverts commit 39b81e7.
1 parent 7200411 commit 4a27c75

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

tests/lowering/embedding/test_embedding.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
import torch_ttnn
33
import pytest
44
import ttnn
5-
from tests.utils import assert_with_pcc
65

76

87
class EmbeddingModule(torch.nn.Module):
@@ -24,7 +23,7 @@ def forward(self, input, weights):
2423

2524
@pytest.mark.parametrize(
2625
"input_shapes",
27-
[[((1, 2, 4, 5), (4, 3, 2, 9)), (10, 4)], [((0, 1, 2, 3)), (4, 2)]],
26+
[[((1, 2, 4, 5), (4, 3, 2, 9)), (10, 4)]],
2827
)
2928
def test_embedding(device, input_shapes):
3029
m = EmbeddingModule()
@@ -42,7 +41,7 @@ def test_embedding(device, input_shapes):
4241
nodes = list(option._out_fx_graphs[0].nodes)
4342
assert [node.target for node in nodes].count(ttnn.embedding) == 1
4443
# Check inference result
45-
assert_with_pcc(result_before, result_after)
44+
assert torch.allclose(result_before, result_after)
4645

4746

4847
@pytest.mark.parametrize(

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -558,13 +558,10 @@ def reshape_1d(code, args=args, kwargs=kwargs):
558558
return g.call_function(ttnn.add, args=(beta_node, new_node))
559559

560560
if node.target == torch.ops.aten.embedding.default:
561-
tensor_meta = args[1].meta.get("val")
562-
tiled = False if tensor_meta is None else tensor_meta.size()[-1] % ttnn.TILE_SIZE == 0
561+
tiled = args[1].meta.get("val")
562+
tiled = False if tiled is None else tiled.size()[-1] % ttnn.TILE_SIZE == 0
563563
layout = TtnnTileLayout() if tiled else TtnnRowMajorLayout()
564564
tensor = g.call_function(ttnn.embedding, (args[1], args[0]), {"layout": layout})
565-
# TODO: Remove this squeeze when issue is fixed: https://github.com/tenstorrent/pytorch2.0_ttnn/issues/660
566-
if len(tensor_meta.size()) == 1:
567-
tensor = g.call_function(ttnn.squeeze, (tensor, 0))
568565
return tensor if tiled else g.call_function(ttnn.to_layout, (tensor, TtnnTileLayout()))
569566

570567
if node.target == torch.ops.aten._log_softmax.default:

0 commit comments

Comments (0)