Skip to content

Commit 68a590d

Browse files
authored
Fix model breakages (#53)
* Consolidate metadata during op conversion * Unmark gpt2 and mnist test models to expect passing * Disable conversion from aten._to_copy * Pass device for all from_torch ops * Replace aten.full op to a literal scalar for certain cases * Compare only Tensor types for dictionary outputs * Replace aten.view with aten.reshape * Unmark bloom, llama, and yolos from xfail * Add conversion for aten.min * Add exception to aten.eq conversion * Fix reusing ttnn data movement op if mixed with aten ops * Convert all inputs to ttnn.bfloat16 when moving data in * Skip unsqueeze transformation if last dim of input is not the same as last dim of output * Add exception to aten.expand conversion when last dimension of input is 1 * Support list type arguments * Check layout change for ttnn reshape and embedding op * Freeze encoder for llama model * Add workaround for ttnn.permute when dim 0 is 1 for rank 3 * Reconvert int64 types from metadata when mixing ttnn and aten ops * Check for valid page size for ops that decompose to ttnn.full * Delete aten.expand op if output has the exact same shape * Mark GPT-2 model as xfail * Update README with new model stats * Fix output type of aten.arange unit test to match output of original * Disable to_copy unit test to re-evaluate conversion * Lower pcc for addmm slightly * Change input shapes of some unit test to match exceptions in current state of lowering * Fix page size validation for conversions involving ttnn.full ops * Update README * Revert changes to GPT-2 since it isn't working in this PR * Remove commented out code * Revert "Pass device for all from_torch ops" This reverts commit 775fb9f.
1 parent 518d524 commit 68a590d

File tree

13 files changed

+449
-218
lines changed

13 files changed

+449
-218
lines changed

README.md

Lines changed: 119 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ This project allows to run PyTorch code on [Tenstorrent](https://tenstorrent.com
88

99
The table below summarizes the results of running various ML models through our TTNN compiler. For each model, we track whether the run was successful, the number of operations before and after conversion, the number of `to_device` and `from_device` operations, performance metrics, and accuracy.
1010

11-
| Model | Run Success | Torch Ops Before (Unique Ops) | Torch Ops Remain (Unique Ops) | To/From Device Ops | Original Run Time (ms) | Compiled Run Time (ms) | Accuracy (%) |
12-
|:------------------------------------|:--------------|:--------------------------------|:--------------------------------|:---------------------|-------------------------:|:-------------------------|:---------------|
13-
| [Mnist (Eval)](tests/models/mnist) | | 14 (8) | 5 (4) | 16 | 36.12 | N/A | N/A |
14-
| [Mnist (Train)](tests/models/mnist) || 14 (8) | 7 (5) | 14 | 114.49 | 2742.8 | 81.75 |
15-
| [ResNet18](tests/models/resnet) || 70 (9) | 42 (4) | 47 | 2094.6 | 10950.18 | 99.99 |
16-
| [Bloom](tests/models/bloom) | | 1407 (29) | N/A | N/A | 9127.68 | N/A | N/A |
17-
| [YOLOS](tests/models/yolos) | | 964 (28) | N/A | N/A | 1353.22 | N/A | N/A |
18-
| [Llama](tests/models/llama) | | 3 (3) | 1 (1) | 5 | 52926.3 | N/A | N/A |
19-
| [BERT](tests/models/bert) || 1393 (21) | 537 (4) | 1607 | 65342 | 61028.65 | 98.64 |
20-
| [Falcon](tests/models/falcon) || 3 (3) | 1 (1) | 5 | 47738.8 | N/A | N/A |
21-
| [GPT-2](tests/models/gpt2) || 748 (31) | N/A | N/A | 2287.61 | N/A | N/A |
11+
| Model | Run Success | Torch Ops Before (Unique Ops) | Torch Ops Remain (Unique Ops) | To/From Device Ops | Original Run Time (ms) | Compiled Run Time (ms) | Accuracy (%) |
12+
|:------------------------------------|:--------------|:--------------------------------|:--------------------------------|---------------------:|-------------------------:|:-------------------------|:---------------|
13+
| [Mnist (Eval)](tests/models/mnist) | | 14 (8) | 5 (4) | 16 | 38.64 | 501.5 | 99.85 |
14+
| [Mnist (Train)](tests/models/mnist) || 14 (8) | 7 (5) | 14 | 136.38 | 2709.01 | 66.84 |
15+
| [ResNet18](tests/models/resnet) || 70 (9) | 42 (4) | 47 | 2131.05 | 9985.44 | 99.99 |
16+
| [Bloom](tests/models/bloom) | | 1407 (29) | 626 (11) | 1379 | 28892.3 | 68470.67 | 45.77 |
17+
| [YOLOS](tests/models/yolos) | | 964 (28) | 409 (11) | 919 | 1410.28 | 45328.58 | 71.71 |
18+
| [Llama](tests/models/llama) | | 5 (4) | 3 (2) | 3 | 206771 | 187910.29 | 45.46 |
19+
| [BERT](tests/models/bert) || 1393 (21) | 539 (5) | 1513 | 67347.3 | 60024.8 | 98.64 |
20+
| [Falcon](tests/models/falcon) || 3 (3) | 2 (2) | 5 | 51366.6 | N/A | N/A |
21+
| [GPT-2](tests/models/gpt2) || 748 (31) | 316 (12) | 569 | 5711.32 | N/A | N/A |
2222

2323
### Explanation of Metrics
2424

@@ -47,7 +47,7 @@ The table below summarizes the results of running various ML models through our
4747
| aten.max_pool2d_with_indices.default || 1 |
4848
| aten.relu.default || 3 |
4949
| aten.t.default || 2 |
50-
| aten.view.default | | 1 |
50+
| aten.view.default | | 1 |
5151
#### Mnist (Train)
5252
| aten ops | status | count |
5353
|:-------------------------------------|:---------|--------:|
@@ -58,7 +58,7 @@ The table below summarizes the results of running various ML models through our
5858
| aten.native_dropout.default || 2 |
5959
| aten.relu.default || 3 |
6060
| aten.t.default || 2 |
61-
| aten.view.default | | 1 |
61+
| aten.view.default | | 1 |
6262
#### ResNet18
6363
| aten ops | status | count |
6464
|:--------------------------------------------------|:---------|--------:|
@@ -70,18 +70,82 @@ The table below summarizes the results of running various ML models through our
7070
| aten.mean.dim || 1 |
7171
| aten.relu.default || 17 |
7272
| aten.t.default || 1 |
73-
| aten.view.default || 1 |
73+
| aten.view.default || 1 |
74+
#### Bloom
75+
| aten ops | status | count |
76+
|:-------------------------------|:---------|--------:|
77+
| aten._softmax.default || 24 |
78+
| aten._to_copy.default || 54 |
79+
| aten._unsafe_view.default || 24 |
80+
| aten.add.Tensor || 96 |
81+
| aten.addmm.default || 96 |
82+
| aten.arange.start || 1 |
83+
| aten.baddbmm.default || 24 |
84+
| aten.bmm.default || 24 |
85+
| aten.clone.default || 96 |
86+
| aten.cumsum.default || 1 |
87+
| aten.embedding.default || 1 |
88+
| aten.expand.default || 2 |
89+
| aten.full.default || 1 |
90+
| aten.lift_fresh_copy.default || 1 |
91+
| aten.masked_fill.Scalar || 26 |
92+
| aten.mm.default || 1 |
93+
| aten.mul.Tensor || 146 |
94+
| aten.native_layer_norm.default || 50 |
95+
| aten.permute.default || 48 |
96+
| aten.pow.Tensor_Tensor || 1 |
97+
| aten.rsub.Scalar || 1 |
98+
| aten.select.int || 72 |
99+
| aten.slice.Tensor || 78 |
100+
| aten.sub.Tensor || 1 |
101+
| aten.t.default || 97 |
102+
| aten.tanh.default || 24 |
103+
| aten.transpose.int || 48 |
104+
| aten.unsqueeze.default || 6 |
105+
| aten.view.default || 363 |
106+
#### YOLOS
107+
| aten ops | status | count |
108+
|:-------------------------------|:---------|--------:|
109+
| aten._softmax.default || 12 |
110+
| aten._to_copy.default || 2 |
111+
| aten._unsafe_index.Tensor || 16 |
112+
| aten.add.Tensor || 71 |
113+
| aten.addmm.default || 78 |
114+
| aten.arange.default || 4 |
115+
| aten.bmm.default || 24 |
116+
| aten.cat.default || 2 |
117+
| aten.clamp.default || 32 |
118+
| aten.clone.default || 50 |
119+
| aten.convolution.default || 1 |
120+
| aten.div.Tensor || 12 |
121+
| aten.expand.default || 50 |
122+
| aten.floor.default || 2 |
123+
| aten.gelu.default || 12 |
124+
| aten.mul.Tensor || 82 |
125+
| aten.native_layer_norm.default || 25 |
126+
| aten.permute.default || 48 |
127+
| aten.relu.default || 4 |
128+
| aten.rsub.Scalar || 10 |
129+
| aten.select.int || 1 |
130+
| aten.sigmoid.default || 1 |
131+
| aten.slice.Tensor || 12 |
132+
| aten.sub.Tensor || 36 |
133+
| aten.t.default || 78 |
134+
| aten.transpose.int || 15 |
135+
| aten.unsqueeze.default || 1 |
136+
| aten.view.default || 283 |
74137
#### Llama
75-
| aten ops | status | count |
76-
|:-----------------------|:---------|--------:|
77-
| aten.arange.start || 1 |
78-
| aten.embedding.default || 1 |
79-
| aten.unsqueeze.default || 1 |
138+
| aten ops | status | count |
139+
|:----------------------|:---------|--------:|
140+
| aten._to_copy.default || 1 |
141+
| aten.mm.default || 1 |
142+
| aten.t.default || 1 |
143+
| aten.view.default || 2 |
80144
#### BERT
81145
| aten ops | status | count |
82146
|:-------------------------------|:---------|--------:|
83147
| aten._softmax.default || 24 |
84-
| aten._to_copy.default | | 1 |
148+
| aten._to_copy.default | | 1 |
85149
| aten.add.Tensor || 74 |
86150
| aten.addmm.default || 145 |
87151
| aten.bmm.default || 48 |
@@ -100,13 +164,47 @@ The table below summarizes the results of running various ML models through our
100164
| aten.t.default || 145 |
101165
| aten.transpose.int || 24 |
102166
| aten.unsqueeze.default || 2 |
103-
| aten.view.default | | 530 |
167+
| aten.view.default | | 530 |
104168
#### Falcon
105169
| aten ops | status | count |
106170
|:-----------------------|:---------|--------:|
107171
| aten.arange.start || 1 |
108172
| aten.embedding.default || 1 |
109173
| aten.unsqueeze.default || 1 |
174+
#### GPT-2
175+
| aten ops | status | count |
176+
|:-------------------------------|:---------|--------:|
177+
| aten._softmax.default || 12 |
178+
| aten._to_copy.default || 2 |
179+
| aten.add.Tensor || 61 |
180+
| aten.addmm.default || 48 |
181+
| aten.arange.default || 1 |
182+
| aten.arange.start || 1 |
183+
| aten.argmax.default || 1 |
184+
| aten.bmm.default || 24 |
185+
| aten.clone.default || 49 |
186+
| aten.div.Tensor || 12 |
187+
| aten.embedding.default || 2 |
188+
| aten.eq.Scalar || 1 |
189+
| aten.expand.default || 48 |
190+
| aten.full.default || 24 |
191+
| aten.index.Tensor || 1 |
192+
| aten.mm.default || 1 |
193+
| aten.mul.Tensor || 49 |
194+
| aten.native_layer_norm.default || 25 |
195+
| aten.permute.default || 48 |
196+
| aten.pow.Tensor_Scalar || 12 |
197+
| aten.remainder.Scalar || 1 |
198+
| aten.rsub.Scalar || 1 |
199+
| aten.slice.Tensor || 50 |
200+
| aten.split.Tensor || 12 |
201+
| aten.sub.Tensor || 1 |
202+
| aten.t.default || 1 |
203+
| aten.tanh.default || 12 |
204+
| aten.transpose.int || 12 |
205+
| aten.unsqueeze.default || 3 |
206+
| aten.view.default || 221 |
207+
| aten.where.self || 12 |
110208

111209

112210
## Quickstart

tests/lowering/creation/test_arange.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def forward(self, start, end, step):
3838
)
3939
def test_arange(device, input_shapes):
4040
m = ArangeModule()
41-
result_before = m.forward(*input_shapes).to(torch.bfloat16)
41+
result_before = m.forward(*input_shapes)
4242
option = torch_ttnn.TorchTtnnOption(device=device)
4343
option.gen_graphviz = True
4444
# The compilation is lazy, so we need to run forward once to trigger the compilation
@@ -59,7 +59,7 @@ def test_arange(device, input_shapes):
5959
)
6060
def test_arange_start(device, input_shapes):
6161
m = ArangeStartModule()
62-
result_before = m.forward(*input_shapes).to(torch.bfloat16)
62+
result_before = m.forward(*input_shapes)
6363
option = torch_ttnn.TorchTtnnOption(device=device)
6464
option.gen_graphviz = True
6565
# The compilation is lazy, so we need to run forward once to trigger the compilation
@@ -80,7 +80,7 @@ def test_arange_start(device, input_shapes):
8080
)
8181
def test_arange_start_step(device, input_shapes):
8282
m = ArangeStartStepModule()
83-
result_before = m.forward(*input_shapes).to(torch.bfloat16)
83+
result_before = m.forward(*input_shapes)
8484
option = torch_ttnn.TorchTtnnOption(device=device)
8585
option.gen_graphviz = True
8686
# The compilation is lazy, so we need to run forward once to trigger the compilation

tests/lowering/creation/test_to_copy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def forward(self, x):
2121
return torch.add(to, to)
2222

2323

24+
# aten.to_copy is used to convert a dtype to another.
25+
# TODO: Will need to re-evaluate the conversion.
26+
@pytest.mark.xfail
2427
@pytest.mark.parametrize(
2528
"input_shapes",
2629
[[(4, 4)]],
@@ -43,6 +46,7 @@ def test_to_copy(device, input_shapes):
4346
assert torch.allclose(result_before, result_after, rtol=0.2)
4447

4548

49+
@pytest.mark.xfail
4650
@pytest.mark.parametrize(
4751
"input_shapes",
4852
[[(4, 4)]],

tests/lowering/eltwise/binary/test_div.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ def forward(self, numerator, denominator):
1515

1616
@pytest.mark.parametrize(
1717
"input_shapes",
18-
[[(4, 4), (4, 4)]],
18+
[[(4, 4), (4, 4)], [(64, 128), (64, 128)]],
1919
)
2020
def test_div(device, input_shapes):
2121
m = DivModule()
22-
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
22+
inputs = [torch.randint(1, 15, shape).to(torch.bfloat16) for shape in input_shapes]
2323
result_before = m.forward(*inputs)
2424
option = torch_ttnn.TorchTtnnOption(device=device)
2525
option.gen_graphviz = True
@@ -45,7 +45,7 @@ def test_div(device, input_shapes):
4545

4646
@pytest.mark.parametrize(
4747
"input_shapes",
48-
[[(4, 4)]],
48+
[[(4, 4)], [(32, 32)]],
4949
)
5050
def test_div_scalar_denom(device, input_shapes):
5151
m = DivModule()

tests/lowering/eltwise/binary/test_sub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_rsub(device, input_shapes):
8080

8181
@pytest.mark.parametrize(
8282
"input_shapes",
83-
[[(4, 4)]],
83+
[[(4, 4)], [(32, 32)]],
8484
)
8585
def test_rsub_scalar(device, input_shapes):
8686
m = RSubScalarModule()

tests/lowering/matmul/test_addmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,4 @@ def test_addmm(device, input_shapes):
4141
if node.target == ttnn.matmul:
4242
assert node.meta["val"].size() == input_shapes[0]
4343
# Check inference result
44-
assert_with_pcc(result_before, result_after)
44+
assert_with_pcc(result_before, result_after, pcc=0.999)

tests/models/bloom/test_bloom.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from transformers import AutoTokenizer, AutoModelForCausalLM
66

77

8-
@pytest.mark.xfail
98
def test_bloom(record_property):
109
record_property("model_name", "Bloom")
1110

@@ -19,7 +18,14 @@ def test_bloom(record_property):
1918

2019
# Set up sample input
2120
test_input = "This is a sample text from "
22-
inputs = tokenizer(test_input, return_tensors="pt")
21+
inputs = tokenizer.encode_plus(
22+
test_input,
23+
return_tensors="pt",
24+
max_length=32,
25+
padding="max_length",
26+
add_special_tokens=True,
27+
truncation=True,
28+
)
2329

2430
# Run inference with the original model
2531
with torch.no_grad():

tests/models/llama/test_llama.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from transformers import AutoTokenizer, AutoModelForCausalLM
66

77

8-
@pytest.mark.xfail
98
def test_llama(record_property):
109
record_property("model_name", "Llama")
1110

@@ -14,12 +13,22 @@ def test_llama(record_property):
1413
tokenizer = AutoTokenizer.from_pretrained(
1514
model_name, padding_side="left", torch_dtype=torch.bfloat16
1615
)
16+
tokenizer.pad_token = tokenizer.eos_token
1717
m = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
18+
for param in m.parameters():
19+
param.requires_grad = False
1820
m.eval()
1921

2022
# Set up sample input
2123
test_input = "This is a sample text from "
22-
inputs = tokenizer(test_input, return_tensors="pt")
24+
inputs = tokenizer.encode_plus(
25+
test_input,
26+
return_tensors="pt",
27+
max_length=32,
28+
padding="max_length",
29+
add_special_tokens=True,
30+
truncation=True,
31+
)
2332

2433
# Run inference with the original model
2534
with torch.no_grad():

tests/models/mnist/test_mnist.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def test_mnist_train(record_property):
5858
record_property("torch_ttnn", (m, test_input, outputs))
5959

6060

61-
@pytest.mark.xfail
6261
def test_mnist_eval(record_property):
6362
record_property("model_name", "Mnist (Eval)")
6463

tests/models/yolos/test_yolos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from transformers import AutoImageProcessor, AutoModelForObjectDetection
88

99

10-
@pytest.mark.xfail
10+
# @pytest.mark.xfail
1111
def test_yolos(record_property):
1212
record_property("model_name", "YOLOS")
1313

0 commit comments

Comments
 (0)