Commit 8ebfec2

Move torch_ttnn compilation process away from individual models

Parent: 54e45e3

14 files changed (+219, -342 lines)

README.md

Lines changed: 32 additions & 9 deletions

````diff
@@ -10,15 +10,15 @@ The table below summarizes the results of running various ML models through our
 
 | Model | Run Success | Torch Ops Before (Unique Ops) | Torch Ops Remain (Unique Ops) | To/From Device Ops | Original Run Time (ms) | Compiled Run Time (ms) | Accuracy (%) |
 |:------------------------------------|:--------------|:--------------------------------|:--------------------------------|:---------------------|-------------------------:|:-------------------------|:---------------|
-| [Mnist (Eval)](tests/models/mnist) || 14 (8) | 5 (4) | 12 | 11.04 | N/A | N/A |
-| [Mnist (Train)](tests/models/mnist) || 14 (8) | 7 (5) | 14 | 18.01 | 2922.51 | 85.88 |
-| [ResNet18](tests/models/resnet) || 70 (9) | 42 (4) | 45 | 1772.4 | 8398.87 | 99.99 |
-| [Bloom](tests/models/bloom) || 1407 (29) | N/A | N/A | 5602.6 | N/A | N/A |
-| [YOLOS](tests/models/yolos) || 964 (28) | N/A | N/A | 209.04 | N/A | N/A |
-| [Llama](tests/models/llama) || 3 (3) | 1 (1) | 5 | 38255.4 | N/A | N/A |
-| [BERT](tests/models/bert) || 1393 (21) | 537 (4) | 1388 | 61919.4 | 52814.88 | 98.64 |
-| [Falcon](tests/models/falcon) || 3 (3) | 1 (1) | 5 | 35014.3 | N/A | N/A |
-| [GPT-2](tests/models/gpt2) || 748 (31) | N/A | N/A | 1033.47 | N/A | N/A |
+| [Mnist (Eval)](tests/models/mnist) || 14 (8) | 5 (4) | 16 | 36.12 | N/A | N/A |
+| [Mnist (Train)](tests/models/mnist) || 14 (8) | 7 (5) | 14 | 114.49 | 2742.8 | 81.75 |
+| [ResNet18](tests/models/resnet) || 70 (9) | 42 (4) | 47 | 2094.6 | 10950.18 | 99.99 |
+| [Bloom](tests/models/bloom) || 1407 (29) | N/A | N/A | 9127.68 | N/A | N/A |
+| [YOLOS](tests/models/yolos) || 964 (28) | N/A | N/A | 1353.22 | N/A | N/A |
+| [Llama](tests/models/llama) || 3 (3) | 1 (1) | 5 | 52926.3 | N/A | N/A |
+| [BERT](tests/models/bert) || 1393 (21) | 537 (4) | 1607 | 65342 | 61028.65 | 98.64 |
+| [Falcon](tests/models/falcon) || 3 (3) | 1 (1) | 5 | 47738.8 | N/A | N/A |
+| [GPT-2](tests/models/gpt2) || 748 (31) | N/A | N/A | 2287.61 | N/A | N/A |
 
 ### Explanation of Metrics
 
@@ -173,3 +173,26 @@ PYTHONPATH=${TT_METAL_HOME}:$(pwd) python3 tools/run_transformers.py --model "ph
 ```
 
 You can also substitute the backend with `torch_stat` to run a reference comparison.
+
+# Add a model test
+If you want to record run-time metrics for a model or test, include a Pytest fixture named `record_property` as a parameter and set the "model_name" key.
+If you also want to compile the model with the torch_ttnn backend, set the `torch_ttnn` key to a tuple in this order: `(model, test_inputs, outputs)`. The "model_name" key still needs to be set. See the example code snippet below. Currently, only `torch.nn.Module` models with a `forward` function are supported.
+```
+class Model(torch.nn.Module):
+    def forward(self, x):
+        # ...
+        return outputs
+
+def test_model_name(record_property):
+    # Should be set as early as possible
+    record_property("model_name", "Model Name")
+
+    model = Model()
+    # ...
+    outputs = model(test_input)
+    # outputs = model(**test_inputs)  # dictionary inputs are also supported
+    # ...
+
+    # Can be set once all three objects for the tuple are defined
+    record_property("torch_ttnn", (model, test_input, outputs))
+```
````

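These recorded properties are consumed by the new autouse fixture in `tests/conftest.py` (shown below), which pickles run-time metrics to `metrics/<model_name>/original-run_time_metrics.pickle` and, when a `torch_ttnn` tuple was recorded, to `metrics/<model_name>/compiled-run_time_metrics.pickle`. A minimal sketch for inspecting those files after a test run; the directory layout and the `success`, `run_time`, `model_path`, and `accuracy` keys are taken from the fixture, and "BERT" is just one of the names the tests register:

```
import pickle
from pathlib import Path

# Inspect the metrics pickled by the compile_and_run fixture.
# "BERT" matches record_property("model_name", "BERT") in test_bert.py.
metrics_dir = Path("metrics/BERT")
for stage in ("original", "compiled"):
    path = metrics_dir / f"{stage}-run_time_metrics.pickle"
    if path.exists():
        with open(path, "rb") as f:
            # e.g. {"success": True, "run_time": 61028.65, "accuracy": 98.64}
            print(stage, pickle.load(f))
```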
docs/README.md.in

Lines changed: 23 additions & 0 deletions

````diff
@@ -81,3 +81,26 @@ PYTHONPATH=${{TT_METAL_HOME}}:$(pwd) python3 tools/run_transformers.py --model "
 ```
 
 You can also substitute the backend with `torch_stat` to run a reference comparison.
+
+# Add a model test
+If you want to record run-time metrics for a model or test, include a Pytest fixture named `record_property` as a parameter and set the "model_name" key.
+If you also want to compile the model with the torch_ttnn backend, set the `torch_ttnn` key to a tuple in this order: `(model, test_inputs, outputs)`. The "model_name" key still needs to be set. See the example code snippet below. Currently, only `torch.nn.Module` models with a `forward` function are supported.
+```
+class Model(torch.nn.Module):
+    def forward(self, x):
+        # ...
+        return outputs
+
+def test_model_name(record_property):
+    # Should be set as early as possible
+    record_property("model_name", "Model Name")
+
+    model = Model()
+    # ...
+    outputs = model(test_input)
+    # outputs = model(**test_inputs)  # dictionary inputs are also supported
+    # ...
+
+    # Can be set once all three objects for the tuple are defined
+    record_property("torch_ttnn", (model, test_input, outputs))
+```
````

tests/conftest.py

Lines changed: 65 additions & 0 deletions

```diff
@@ -1,6 +1,13 @@
 import pytest
 import ttnn
 import torch
+import torch_ttnn
+import collections.abc  # collections.Mapping/Sequence were removed in Python 3.10
+from tests.utils import calculate_accuracy
+import time
+from pathlib import Path
+import os
+import pickle
 
 
 @pytest.fixture(scope="session")
@@ -15,3 +22,61 @@ def reset_torch_dynamo():
     # PyTorch caches models. Start a fresh compile for each parameter of the test case.
     torch._dynamo.reset()
     yield
+
+
+@pytest.fixture(autouse=True)
+def compile_and_run(device, reset_torch_dynamo, request):
+    try:
+        start = time.perf_counter() * 1000
+        yield
+        end = time.perf_counter() * 1000
+        runtime_metrics = {"success": True, "run_time": round(end - start, 2)}
+    except Exception as e:
+        runtime_metrics = {"success": False}
+        # model_name is not known yet here; fall back to the pytest node name
+        print(f"{request.node.name} original failed to run. Raised exception: {e}")
+        raise
+    finally:
+        record = dict(request.node.user_properties)
+        model_path = Path(request.node.location[0])
+        runtime_metrics["model_path"] = str(model_path.parent)
+        if "model_name" in record:
+            model_name = record["model_name"]
+            p = Path(f"metrics/{model_name}")
+            os.makedirs(p, exist_ok=True)
+
+            original_metrics_path = p / "original-run_time_metrics.pickle"
+            with open(original_metrics_path, "wb") as f:
+                pickle.dump(runtime_metrics, f)
+
+            if "torch_ttnn" in record:
+                model, inputs, outputs = record["torch_ttnn"]
+                try:
+                    # Check that the model implements a forward function
+                    assert "forward" in dir(model), f"forward() not implemented in {model_name}"
+                    # Compile model with ttnn backend
+                    option = torch_ttnn.TorchTtnnOption(
+                        device=device, gen_graphviz=True, metrics_path=model_name
+                    )
+                    m = torch.compile(model, backend=torch_ttnn.backend, options=option)
+
+                    start = time.perf_counter() * 1000
+                    if isinstance(inputs, collections.abc.Mapping):
+                        outputs_after = m(**inputs)
+                    elif isinstance(inputs, collections.abc.Sequence):
+                        outputs_after = m(*inputs)
+                    else:
+                        outputs_after = m(inputs)
+                    end = time.perf_counter() * 1000
+                    comp_runtime_metrics = {"success": True, "run_time": round(end - start, 2)}
+                    option._out_fx_graphs[0].print_tabular()
+                    accuracy = calculate_accuracy(outputs, outputs_after)
+                    if accuracy:
+                        comp_runtime_metrics["accuracy"] = accuracy
+                except Exception as e:
+                    comp_runtime_metrics = {"success": False}
+                    print(f"{model_name} compiled failed to run. Raised exception: {e}")
+                    raise
+                finally:
+                    compiled_metrics_path = p / "compiled-run_time_metrics.pickle"
+                    with open(compiled_metrics_path, "wb") as f:
+                        pickle.dump(comp_runtime_metrics, f)
```

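The fixture delegates output comparison to `calculate_accuracy` from `tests/utils`, whose implementation this commit does not include. Below is a hypothetical sketch of such a helper, assuming it extracts comparable tensors and reports a Pearson correlation coefficient as a percentage; both are assumptions, and the repository's actual helper may differ:

```
import torch

def calculate_accuracy(original, compiled):
    # Hypothetical stand-in for tests/utils.calculate_accuracy; the real
    # implementation is not part of this diff.
    # transformers ModelOutput objects carry their scores in .logits.
    t_orig = original.logits if hasattr(original, "logits") else original
    t_comp = compiled.logits if hasattr(compiled, "logits") else compiled
    if not (torch.is_tensor(t_orig) and torch.is_tensor(t_comp)):
        return None  # falsy: the fixture simply skips the accuracy key
    # Pearson correlation between the flattened outputs, as a percentage,
    # in line with the README's Accuracy (%) column.
    stacked = torch.stack([t_orig.flatten().float(), t_comp.flatten().float()])
    pcc = torch.corrcoef(stacked)[0, 1].item()
    return round(pcc * 100, 2)
```

Returning a falsy value for incomparable outputs matches the fixture's `if accuracy:` guard, which records the key only when a number came back.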
tests/models/bert/test_bert.py

Lines changed: 7 additions & 27 deletions

```diff
@@ -1,13 +1,12 @@
 import torch
-import torch_ttnn
-import pytest
-from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 
 
-def test_bert(device):
+def test_bert(record_property):
+    record_property("model_name", "BERT")
+
     # Download model from cloud
     model_name = "phiyodr/bert-large-finetuned-squad2"
     tokenizer = AutoTokenizer.from_pretrained(
@@ -32,10 +31,9 @@ def test_bert(device):
         truncation=True,
     )
 
-    metrics_path = "BERT"
     # Run inference with the original model
     with torch.no_grad():
-        outputs_before = RunTimeMetrics(metrics_path, "original", lambda: m(**inputs))
+        outputs = m(**inputs)
 
     # Helper function to decode output to human-readable text
     def decode_output(outputs):
@@ -44,34 +42,16 @@ def decode_output(outputs):
         response_tokens = inputs.input_ids[0, response_start:response_end]
         return tokenizer.decode(response_tokens)
 
-    answer_before = decode_output(outputs_before)
-
-    # Compile model with ttnn backend
-    option = torch_ttnn.TorchTtnnOption(
-        device=device, gen_graphviz=True, metrics_path=metrics_path
-    )
-    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
-
-    # Run inference with the compiled model
-    with torch.no_grad():
-        outputs_after = RunTimeMetrics(metrics_path, "compiled", lambda: m(**inputs))
-
-    option._out_fx_graphs[0].print_tabular()
-
-    answer_after = decode_output(outputs_after)
+    answer = decode_output(outputs)
 
     print(
         f"""
     model_name: {model_name}
     input:
         context: {context}
         question: {question}
-    answer before: {answer_before}
-    answer after: {answer_after}
+    answer: {answer}
     """
     )
 
-    # TODO: Add more checks for the compiled graph
-
-    # Check inference result
-    assert answer_before == answer_after
+    record_property("torch_ttnn", (m, inputs, outputs))
```

tests/models/bloom/test_bloom.py

Lines changed: 7 additions & 25 deletions

```diff
@@ -1,14 +1,14 @@
 import torch
-import torch_ttnn
 import pytest
-from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 @pytest.mark.xfail
-def test_bloom(device):
+def test_bloom(record_property):
+    record_property("model_name", "Bloom")
+
     # Download model from cloud
     model_name = "bigscience/bloom-1b1"
     tokenizer = AutoTokenizer.from_pretrained(
@@ -21,42 +21,24 @@ def test_bloom(device):
     test_input = "This is a sample text from "
     inputs = tokenizer(test_input, return_tensors="pt")
 
-    metrics_path = "Bloom"
     # Run inference with the original model
     with torch.no_grad():
-        outputs_before = RunTimeMetrics(metrics_path, "original", lambda: m(**inputs))
+        outputs = m(**inputs)
 
     # Helper function to decode output to human-readable text
     def decode_output(outputs):
         next_token_logits = outputs.logits[:, -1]
         next_token = next_token_logits.softmax(dim=-1).argmax()
         return tokenizer.decode([next_token])
 
-    decoded_output_before = decode_output(outputs_before)
-
-    # Compile model with ttnn backend
-    option = torch_ttnn.TorchTtnnOption(
-        device=device, gen_graphviz=True, metrics_path=metrics_path
-    )
-    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
-
-    # Run inference with the compiled model
-    with torch.no_grad():
-        outputs_after = RunTimeMetrics(metrics_path, "compiled", lambda: m(**inputs))
-    option._out_fx_graphs[0].print_tabular()
-
-    decoded_output_after = decode_output(outputs_after)
+    decoded_output = decode_output(outputs)
 
     print(
         f"""
     model_name: {model_name}
     input: {test_input}
-    output before: {decoded_output_before}
-    output after: {decoded_output_after}
+    output: {decoded_output}
     """
     )
 
-    # TODO: Add more checks for the compiled graph
-
-    # Check inference result
-    assert decoded_output_before == decoded_output_after
+    record_property("torch_ttnn", (m, inputs, outputs))
```

tests/models/falcon/test_falcon.py

Lines changed: 7 additions & 26 deletions

```diff
@@ -1,14 +1,14 @@
 import torch
-import torch_ttnn
 import pytest
-from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 @pytest.mark.xfail
-def test_falcon(device):
+def test_falcon(record_property):
+    record_property("model_name", "Falcon")
+
     # Download model from cloud
     model_name = "tiiuae/falcon-7b-instruct"
     tokenizer = AutoTokenizer.from_pretrained(
@@ -21,43 +21,24 @@ def test_falcon(device):
     test_input = "This is a sample text from "
     inputs = tokenizer(test_input, return_tensors="pt")
 
-    metrics_path = "Falcon"
     # Run inference with the original model
     with torch.no_grad():
-        outputs_before = RunTimeMetrics(metrics_path, "original", lambda: m(**inputs))
+        outputs = m(**inputs)
 
     # Helper function to decode output to human-readable text
     def decode_output(outputs):
         next_token_logits = outputs.logits[:, -1]
         next_token = next_token_logits.softmax(dim=-1).argmax()
         return tokenizer.decode([next_token])
 
-    decoded_output_before = decode_output(outputs_before)
-
-    # Compile model with ttnn backend
-    option = torch_ttnn.TorchTtnnOption(
-        device=device, gen_graphviz=True, metrics_path=metrics_path
-    )
-    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
-
-    # Run inference with the compiled model
-    with torch.no_grad():
-        outputs_after = RunTimeMetrics(metrics_path, "compiled", lambda: m(**inputs))
-
-    option._out_fx_graphs[0].print_tabular()
-
-    decoded_output_after = decode_output(outputs_after)
+    decoded_output = decode_output(outputs)
 
     print(
         f"""
     model_name: {model_name}
     input: {test_input}
-    output before: {decoded_output_before}
-    output after: {decoded_output_after}
+    output: {decoded_output}
     """
     )
 
-    # TODO: Add more checks for the compiled graph
-
-    # Check inference result
-    assert decoded_output_before == decoded_output_after
+    record_property("torch_ttnn", (m, inputs, outputs))
```
