Commit 4f8daaf

Add conversion for more OPs (#11)
* Use TILE_LAYOUT during data move-in (see the sketch after this list)
* Insert a ttnn.to_layout(ttnn.TILE_LAYOUT) between ttnn.from_torch and ttnn.to_device when adding data move-in ops
* ttnn.reshape will skip inserting ttnn.to_layout
* Update the tests to reflect the newly inserted function
* Fix reshape
* Add conversion from torch.relu and torch.addmm to ttnn
* Add conversion from torch.div, torch.bmm, and torch.gelu to ttnn
* Add workaround to handle input aliasing
* Add conversion from aten.rsub and aten.embedding
* Add conversion from aten.split
* Move the GraphCleanup method to a new file
* Move the Dummy string repr to a separate utils file
* Fix rsub elif
* Add torch.clone conversion to ttnn.clone. ttnn.clone requires extra arguments compared to torch.clone: a MemoryConfig type and an output clone type
* Construct ttnn.MemoryConfig for DRAM
* Retrieve metadata from the original torch op and translate it to the ttnn type
* Add support for kwargs
* Add conversion from torch.nn.functional.layer_norm. torch.nn.LayerNorm does not have parameters for custom weights and bias and produces values that differ quite a bit from ttnn.layer_norm; this is not supported yet. However, torch.nn.functional.layer_norm can produce values that are very close to ttnn.layer_norm, so this commit tests against the aten op lowered from that higher-level torch op. aten.native_layer_norm returns 3 outputs (layer norm, mean, rstd), but torch.nn.functional.layer_norm only cares about the layer norm output. Currently this transformation replaces the mean and rstd with the layer norm output; this should be fixed later
* Add conversion from torch.neg, torch.ones, and torch.tril to their ttnn counterparts
* ttnn.ones requires passing the device object manually
* A default device has to be set up for AutoFormat, since ttnn.tril uses it
* Use a custom class for the kwarg object instead of a generic tuple
* Add transformations for aten.{eq.Tensor, eq.Scalar, logical_not, zeros_like, mean.dim}
* Fix torch.compile options for the other tests
* Move the ttnn.add and ttnn.mul transformations to ToTtPass. This requires a patch to ttnn.decorators.Operation
* Fix test_fall_back
* Add transformations for several more ops: Pow (tensor base, scalar exponent), Rsqrt, Silu, Adaptive Avg Pool, Clamp, and Squeeze (dim argument)
* Fix transformations for torch.eq and add a transformation for torch.full: torch.eq (scalar) -> ttnn.full + ttnn.eq (tensor). ttnn.eq previously supported a scalar argument, but that now errors
* Disable the torch-to-ttnn.split test, since fallback is disabled and the op is not implemented yet
* Update the torch-to-ttnn.reshape tests to match some limitations
* Implement transformations for torch.lq.{scalar,tensor} and generalize relational ops for a cleaner implementation
* Implement the aten.baddbmm transformation to ttnn
* Add a transformation from torch.cos
* Remove the conversion for split because ttnn.split has been removed entirely (see #5389)
* Add a transformation for torch.sigmoid
* Cast all model input arguments to bfloat16
* Set aten.view to fall back; the restrictions from ttnn.reshape still need handling
* Fix the layer_norm conversion to handle cases where ttnn ops follow
* Handle the case where aten.full has an empty shape
* Remove the split conversion from the to_tt pass
* Add a fallback to the squeeze conversion, since ttnn.squeeze only supports dim 0
* Add the aten.rsub.Scalar conversion
* Match restrictions from ttnn.arange
* Add a workaround for the relational op conversion for certain input sizes
* Restrict the embedding conversion to TILE_LAYOUT only for now
* Handle the case where the denominator is a scalar for the div op
* Add a workaround for when the model output takes the output of argmax: aten.argmax outputs integer values, but ttnn.argmax outputs floating point
* Remove extraneous prints
* Add bert and falcon-7b models for testing with the torch_stat backend. These models have "/" in their names; small fix in the torch_stat backend
* Update the AutoModelForCausalLM models and add the bigscience/bloom-1b1 model
* Add mamba, llama, gpt2, and yolos models
* Fix the e2e run with the torch_ttnn backend
* Add an option to select a model
* Remove the dependency on the tests module from tt-metal and copy the relevant test utility functions to this repo
* Update the readme with instructions on running a transformer model with the ttnn backend
* Run the black formatter on files
* Fix formatting for run_transformers
* Install the latest prerelease metal_libs wormhole wheel
* Make sure to use bfloat when using full
* Add the metal_libs wheel to requirements and fix compatibility versions
* Increase verbosity when running tests in the CI
* The TT_METAL_HOME and PYTHONPATH env variables are not needed when installing from the wheel
* Revert "Install latest prerelease metal_libs wormhole wheel" (reverts commit be91195)
* Update the metal libs wheel
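The first few items describe the data move-in rewrite. Below is a minimal sketch of such a pass over a `torch.fx.GraphModule`; `insert_data_movein` is an illustrative name, not the pass as implemented in this repo, and the real pass additionally special-cases ops like ttnn.reshape that skip the layout change:

```python
import torch
import ttnn


def insert_data_movein(gm: torch.fx.GraphModule, device) -> torch.fx.GraphModule:
    # For every graph input, make downstream ops consume
    # from_torch -> to_layout(TILE_LAYOUT) -> to_device instead of the raw tensor.
    g = gm.graph
    for node in list(g.nodes):
        if node.op != "placeholder":
            continue
        with g.inserting_after(node):
            from_torch = g.call_function(ttnn.from_torch, (node,))
        with g.inserting_after(from_torch):
            to_layout = g.call_function(ttnn.to_layout, (from_torch, ttnn.TILE_LAYOUT))
        with g.inserting_after(to_layout):
            to_device = g.call_function(ttnn.to_device, (to_layout, device))
        # Rewire every original user except the chain we just built.
        node.replace_all_uses_with(to_device, delete_user_cb=lambda u: u is not from_torch)
    g.lint()
    gm.recompile()
    return gm
```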
1 parent f8087ae · commit 4f8daaf

22 files changed: +3025 −240 lines

.github/workflows/before_merge.yaml

Lines changed: 1 addition & 3 deletions

```diff
@@ -15,8 +15,6 @@ jobs:
   validate-pr:
     env:
       ARCH_NAME: wormhole_b0
-      TT_METAL_HOME: ${pwd}
-      PYTHONPATH: ${pwd}
     runs-on: ["in-service", "n150"]
     steps:
       - name: Checkout Repo
@@ -35,4 +33,4 @@
         pytest_report_title: "⭐️ Pytest Results ⭐️"
       run: |
         source venv/bin/activate
-        python3 -m pytest --github-report tests/*.py
+        python3 -m pytest --github-report tests/*.py -s
```

README.md

Lines changed: 8 additions & 0 deletions

````diff
@@ -51,3 +51,11 @@ The `*_total_*_size_dist/` statistics the `op_type`'s input/output_size distribution
 - Notice: the [aten ir interface is in there](https://pytorch.org/docs/stable/torch.compiler_ir.html)
 
 [The `profile/` is the tools provided by pytorch](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html), you can open it by the url: chrome://tracing
+
+# Run transformer models
+To run transformer model with ttnn backend, run:
+```
+PYTHONPATH=${TT_METAL_HOME}:$(pwd) python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
+```
+
+You can also substitute the backend with `torch_stat` to run a reference comparison.
````
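For example, the reference run mentioned in the added README text would be the same command with the backend flag swapped:

```
PYTHONPATH=${TT_METAL_HOME}:$(pwd) python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_stat
```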

requirements.txt

Lines changed: 4 additions & 1 deletion

```diff
@@ -1,6 +1,9 @@
+--find-links https://download.pytorch.org/whl/torch_stable.html
+
 torch==2.2.1.0+cpu
 torchvision==0.17.1+cpu
 tabulate==0.9.0
 networkx==3.1
 graphviz
-matplotlib
+matplotlib==3.7.1
+https://github.com/tenstorrent/tt-metal/releases/download/v0.50.0-rc18/metal_libs-0.50.0rc18+wormhole.b0-cp38-cp38-linux_x86_64.whl
```

tests/test_cse.py

Lines changed: 13 additions & 11 deletions

```diff
@@ -19,11 +19,11 @@ def input_shapes(self):
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
-        self.device: ttnn.Device = ttnn.open(0)
+        self.device: ttnn.Device = ttnn.open_device(device_id=0)
 
     def tearDown(self):
         # Close the device
-        ttnn.close(self.device)
+        ttnn.close_device(self.device)
 
     def test_add(self):
         m = AddModule()
@@ -32,19 +32,21 @@ def test_add(self):
         result_before = m.forward(*inputs)
         option = torch_ttnn.TorchTtnnOption(device=self.device)
         # The compilation is lazy, so we need to run forward once to trigger the compilation
-        m = torch.compile(m, backend=torch_ttnn.backend(option))
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
         result_after = m.forward(*inputs)
         self.assertEqual(1, len(option._out_fx_graphs))
         option._out_fx_graphs[0].print_tabular()
         # Check the graph has be rewritten and contain ttnn ops
         nodes = list(option._out_fx_graphs[0].nodes)
-        self.assertEqual(nodes[3].target, ttnn.add)
-        self.assertEqual(nodes[3].args[0].target, ttnn.to_device)
-        self.assertEqual(nodes[3].args[0].args[0].target, ttnn.from_torch)
-        self.assertEqual(nodes[3].args[1].target, ttnn.to_device)
-        self.assertEqual(nodes[3].args[1].args[0].target, ttnn.from_torch)
-        self.assertEqual(nodes[4].target, ttnn.from_device)
-        self.assertEqual(nodes[5].target, ttnn.to_layout)
-        self.assertEqual(nodes[6].target, ttnn.to_torch)
+        self.assertEqual(nodes[4].target, ttnn.add)
+        self.assertEqual(nodes[4].args[0].target, ttnn.to_device)
+        self.assertEqual(nodes[4].args[0].args[0].target, ttnn.to_layout)
+        self.assertEqual(nodes[4].args[0].args[0].args[0].target, ttnn.from_torch)
+        self.assertEqual(nodes[4].args[1].target, ttnn.to_device)
+        self.assertEqual(nodes[4].args[1].args[0].target, ttnn.to_layout)
+        self.assertEqual(nodes[4].args[1].args[0].args[0].target, ttnn.from_torch)
+        self.assertEqual(nodes[5].target, ttnn.from_device)
+        self.assertEqual(nodes[6].target, ttnn.to_layout)
+        self.assertEqual(nodes[7].target, ttnn.to_torch)
         # Check inference result
         self.assertTrue(torch.allclose(result_before, result_after))
```
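The asserted node order reflects the wrapper chain each converted op now receives. A hedged eager-mode equivalent of that chain (the shape and bfloat16 dtype are illustrative):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

x = torch.rand(32, 32, dtype=torch.bfloat16)

# Data move-in: host tensor -> ttnn tensor -> tile layout -> device
t = ttnn.from_torch(x)
t = ttnn.to_layout(t, ttnn.TILE_LAYOUT)
t = ttnn.to_device(t, device)

out = ttnn.add(t, t)  # the op under test

# Data move-out: device -> host -> row-major layout -> torch tensor
out = ttnn.from_device(out)
out = ttnn.to_layout(out, ttnn.ROW_MAJOR_LAYOUT)
result = ttnn.to_torch(out)

ttnn.close_device(device)
```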

tests/test_fall_back.py

Lines changed: 22 additions & 10 deletions

```diff
@@ -3,6 +3,8 @@
 import unittest
 from torch_ttnn import ttnn
 
+from torch_ttnn.utils import check_with_pcc
+
 
 class MixModule(torch.nn.Module):
     def __init__(self):
@@ -23,11 +25,11 @@ def input_shapes(self):
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
-        self.device: ttnn.Device = ttnn.open(0)
+        self.device: ttnn.Device = ttnn.open_device(device_id=0)
 
     def tearDown(self):
         # Close the device
-        ttnn.close(self.device)
+        ttnn.close_device(self.device)
 
     def test_fall_back(self):
         m = MixModule()
@@ -37,19 +39,29 @@ def test_fall_back(self):
         option = torch_ttnn.TorchTtnnOption(device=self.device)
         option.gen_graphviz = True
         # The compilation is lazy, so we need to run forward once to trigger the compilation
-        m = torch.compile(m, backend=torch_ttnn.backend(option))
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
         result_after = m.forward(*inputs)
         self.assertEqual(1, len(option._out_fx_graphs))
         option._out_fx_graphs[0].print_tabular()
 
         # Check the graph has be rewritten and contain ttnn ops
         nodes = list(option._out_fx_graphs[0].nodes)
-        self.assertEqual(nodes[3].target, ttnn.from_torch)
+        self.assertEqual(nodes[2].target, ttnn.from_torch)
+        self.assertEqual(nodes[3].target, ttnn.to_layout)
         self.assertEqual(nodes[4].target, ttnn.to_device)
-        self.assertEqual(nodes[5].target, ttnn.add)
-        self.assertEqual(nodes[6].target, ttnn.matmul)
-        self.assertEqual(nodes[7].target, ttnn.from_device)
-        self.assertEqual(nodes[8].target, ttnn.to_layout)
-        self.assertEqual(nodes[9].target, ttnn.to_torch)
+        self.assertEqual(nodes[5].target, ttnn.reciprocal)
+        self.assertEqual(nodes[6].target, ttnn.from_torch)
+        self.assertEqual(nodes[7].target, ttnn.to_layout)
+        self.assertEqual(nodes[8].target, ttnn.to_device)
+        self.assertEqual(nodes[9].target, ttnn.mul)
+        self.assertEqual(nodes[10].target, ttnn.add)
+        self.assertEqual(nodes[11].target, ttnn.matmul)
+        self.assertEqual(nodes[12].target, ttnn.reciprocal)
+        self.assertEqual(nodes[13].target, ttnn.mul)
+        self.assertEqual(nodes[14].target, ttnn.reciprocal)
+        self.assertEqual(nodes[15].target, ttnn.mul)
+        self.assertEqual(nodes[16].target, ttnn.from_device)
+        self.assertEqual(nodes[17].target, ttnn.to_layout)
+        self.assertEqual(nodes[18].target, ttnn.to_torch)
         # Check inference result
-        self.assertTrue(torch.allclose(result_before, result_after))
+        self.assertTrue(check_with_pcc(result_before, result_after))
```
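`check_with_pcc` is imported from `torch_ttnn.utils` (one of the test utilities copied over from tt-metal per the commit message). A plausible sketch of such a Pearson-correlation check, not necessarily the exact implementation:

```python
import numpy as np
import torch


def check_with_pcc(expected: torch.Tensor, actual: torch.Tensor, pcc: float = 0.99) -> bool:
    if expected.shape != actual.shape:
        return False
    # Flatten to 1-D float arrays and compute the Pearson correlation coefficient.
    # (Constant tensors would need special-casing: corrcoef is undefined for them.)
    a = expected.detach().to(torch.float32).flatten().numpy()
    b = actual.detach().to(torch.float32).flatten().numpy()
    cc = np.corrcoef(np.vstack([a, b]))[0, 1]
    return bool(cc >= pcc)
```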

tests/test_if.py

Lines changed: 17 additions & 15 deletions

```diff
@@ -22,11 +22,11 @@ def input_shapes(self):
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
-        self.device: ttnn.Device = ttnn.open(0)
+        self.device: ttnn.Device = ttnn.open_device(device_id=0)
 
     def tearDown(self):
         # Close the device
-        ttnn.close(self.device)
+        ttnn.close_device(self.device)
 
     def test_if(self):
         m = IfModule()
@@ -36,7 +36,7 @@ def test_if(self):
         result_before_else = m.forward(*inputs_else)
         option = torch_ttnn.TorchTtnnOption(device=self.device)
         # The compilation is lazy, so we need to run forward once to trigger the compilation
-        m = torch.compile(m, backend=torch_ttnn.backend(option))
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
         result_after_then = m.forward(*inputs_then)
         result_after_else = m.forward(*inputs_else)
 
@@ -49,21 +49,23 @@ def test_if(self):
         self.assertEqual(nodes_0[1].target, torch.ops.aten.sum.default)
         self.assertEqual(nodes_0[2].target, torch.ops.aten.gt.Scalar)
         nodes_1 = list(option._out_fx_graphs[1].nodes)
-        self.assertEqual(len(nodes_1), 8)
+        self.assertEqual(len(nodes_1), 9)
         self.assertEqual(nodes_1[1].target, ttnn.from_torch)
-        self.assertEqual(nodes_1[2].target, ttnn.to_device)
-        self.assertEqual(nodes_1[3].target, ttnn.add)
-        self.assertEqual(nodes_1[4].target, ttnn.from_device)
-        self.assertEqual(nodes_1[5].target, ttnn.to_layout)
-        self.assertEqual(nodes_1[6].target, ttnn.to_torch)
+        self.assertEqual(nodes_1[2].target, ttnn.to_layout)
+        self.assertEqual(nodes_1[3].target, ttnn.to_device)
+        self.assertEqual(nodes_1[4].target, ttnn.add)
+        self.assertEqual(nodes_1[5].target, ttnn.from_device)
+        self.assertEqual(nodes_1[6].target, ttnn.to_layout)
+        self.assertEqual(nodes_1[7].target, ttnn.to_torch)
         nodes_2 = list(option._out_fx_graphs[2].nodes)
-        self.assertEqual(len(nodes_2), 8)
+        self.assertEqual(len(nodes_2), 9)
         self.assertEqual(nodes_2[1].target, ttnn.from_torch)
-        self.assertEqual(nodes_2[2].target, ttnn.to_device)
-        self.assertEqual(nodes_2[3].target, ttnn.matmul)
-        self.assertEqual(nodes_2[4].target, ttnn.from_device)
-        self.assertEqual(nodes_2[5].target, ttnn.to_layout)
-        self.assertEqual(nodes_2[6].target, ttnn.to_torch)
+        self.assertEqual(nodes_2[2].target, ttnn.to_layout)
+        self.assertEqual(nodes_2[3].target, ttnn.to_device)
+        self.assertEqual(nodes_2[4].target, ttnn.matmul)
+        self.assertEqual(nodes_2[5].target, ttnn.from_device)
+        self.assertEqual(nodes_2[6].target, ttnn.to_layout)
+        self.assertEqual(nodes_2[7].target, ttnn.to_torch)
 
         # Check inference result
         self.assertTrue(torch.allclose(result_before_then, result_after_then))
```
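All three test files switch the compile call from `torch.compile(m, backend=torch_ttnn.backend(option))` to passing the option object via `options`. A minimal usage sketch of the new form, with a stand-in module and illustrative inputs:

```python
import torch
import torch_ttnn
import ttnn


class AddModule(torch.nn.Module):
    def forward(self, x, y):
        return x + y


device = ttnn.open_device(device_id=0)
option = torch_ttnn.TorchTtnnOption(device=device)

# Compilation is lazy: the first forward call triggers the ttnn rewrite.
m = torch.compile(AddModule(), backend=torch_ttnn.backend, options=option)
out = m(torch.rand(32, 32, dtype=torch.bfloat16),
        torch.rand(32, 32, dtype=torch.bfloat16))

ttnn.close_device(device)
```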
