
Commit 4da248b

Add a QuestionAnswering bert model for testing (#19)
* Add matching transformers version as current tt-metal
* Remove the type conversion passes and directly call transformers to use bfloat16 when initializing the model
* Convert torch.Tensor.to to ttnn.as_tensor
* Use ttnn.full correctly instead of aten.full for certain cases
* Convert bert model to unittest
* Move transformers installation to dev
* Refactor model input and output print statement for test_bert
1 parent 4f8daaf commit 4da248b
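
The "Convert torch.Tensor.to to ttnn.as_tensor" item is the heart of the change: tracing x.to(torch.bfloat16) produces an aten._to_copy node in the FX graph, which the backend now lowers to ttnn.as_tensor (the new ToCopy tests below assert exactly this). A minimal sketch of that kind of rewrite, for illustration only and not the actual torch_ttnn pass:

    import torch
    import ttnn  # assumes a working tt-metal/ttnn installation


    def rewrite_to_copy(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Sketch: replace aten._to_copy(x, dtype=torch.bfloat16) with ttnn.as_tensor.
        # The real pass also wires up device, layout, and memory-config arguments.
        for node in list(gm.graph.nodes):
            if node.op == "call_function" and node.target == torch.ops.aten._to_copy.default:
                if node.kwargs.get("dtype") != torch.bfloat16:
                    continue
                with gm.graph.inserting_after(node):
                    new_node = gm.graph.call_function(
                        ttnn.as_tensor,
                        args=(node.args[0],),
                        kwargs={"dtype": ttnn.bfloat16},
                    )
                node.replace_all_uses_with(new_node)
                gm.graph.erase_node(node)
        gm.graph.lint()
        gm.recompile()
        return gm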


9 files changed (+235, -202 lines)


requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 pytest==7.2.2
 pytest-timeout==2.2.0
 pre-commit==3.0.4
+transformers==4.38.0
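
The pin tracks the transformers version that the current tt-metal release expects (per the commit message). A quick illustrative sanity check:

    import transformers

    # Illustrative: confirm the pinned version from requirements-dev.txt is active.
    assert transformers.__version__ == "4.38.0"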

tests/test_bert.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+import torch
+import torch_ttnn
+import unittest
+from torch_ttnn import ttnn
+import collections
+
+# Load model directly
+from transformers import (
+    AutoTokenizer,
+    AutoModelForQuestionAnswering,
+)
+
+
+class TestBert(unittest.TestCase):
+    def setUp(self):
+        # Open device 0
+        self.device: ttnn.Device = ttnn.open_device(device_id=0)
+
+    def tearDown(self):
+        # Close the device
+        ttnn.close_device(self.device)
+
+    def test_bert(self):
+        # Download model from cloud
+        model_name = "phiyodr/bert-large-finetuned-squad2"
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, padding_side="left", torch_dtype=torch.bfloat16
+        )
+        m = AutoModelForQuestionAnswering.from_pretrained(
+            model_name, torch_dtype=torch.bfloat16
+        )
+        m.eval()
+
+        # Set up sample input
+        context = 'Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. "The prophet and founding hero of modern archaeology", Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art. '
+        question = "What discipline did Winkelmann create?"
+
+        inputs = tokenizer.encode_plus(
+            question,
+            context,
+            add_special_tokens=True,
+            return_tensors="pt",
+            max_length=256,
+            padding="max_length",
+            truncation=True,
+        )
+
+        # Run inference with the original model
+        with torch.no_grad():
+            outputs_before = m(**inputs)
+
+        # Helper function to decode output to human-readable text
+        def decode_output(outputs):
+            response_start = torch.argmax(outputs.start_logits)
+            response_end = torch.argmax(outputs.end_logits) + 1
+            response_tokens = inputs.input_ids[0, response_start:response_end]
+            return tokenizer.decode(response_tokens)
+
+        answer_before = decode_output(outputs_before)
+
+        # Compile model with ttnn backend
+        option = torch_ttnn.TorchTtnnOption(device=self.device)
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+
+        # Run inference with the compiled model
+        with torch.no_grad():
+            outputs_after = m(**inputs)
+            option._out_fx_graphs[0].print_tabular()
+
+        answer_after = decode_output(outputs_after)
+
+        print(
+            f"""
+        model_name: {model_name}
+        input:
+            context: {context}
+            question: {question}
+        answer before: {answer_before}
+        answer after: {answer_after}
+        """
+        )
+
+        # TODO: Add more checks for the compiled graph
+
+        # Check inference result
+        self.assertEqual(answer_before, answer_after)
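
Because both the tokenizer and the model are constructed with torch_dtype=torch.bfloat16, the graph that reaches the backend is already bfloat16; this is what allows the convert_type passes to be deleted from torch_ttnn/backend.py below. An illustrative check, assuming the same checkpoint is reachable:

    import torch
    from transformers import AutoModelForQuestionAnswering

    m = AutoModelForQuestionAnswering.from_pretrained(
        "phiyodr/bert-large-finetuned-squad2", torch_dtype=torch.bfloat16
    )
    # Weights load directly in bfloat16, so no post-hoc FX dtype pass is needed.
    assert next(m.parameters()).dtype == torch.bfloat16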

tests/test_more_ops.py

Lines changed: 88 additions & 20 deletions
@@ -555,6 +555,29 @@ def input_shapes(self):
         return [(4, 4)]
 
 
+class ToCopyModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.to(torch.bfloat16)
+
+    def input_shapes(self):
+        return [(4, 4)]
+
+
+class ToCopyWithOpAfterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        to = x.to(torch.bfloat16)
+        return torch.add(to, to)
+
+    def input_shapes(self):
+        return [(4, 4)]
+
+
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
@@ -894,18 +917,17 @@ def test_div_scalar_denom(self):
 
         # Check the graph has been rewritten and contains ttnn ops
         nodes = list(option._out_fx_graphs[0].nodes)
-        self.assertTrue(nodes[5].target == ttnn.reciprocal)
-        self.assertTrue(nodes[5].args[0].target == ttnn.to_device)
-        self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout)
-        self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch)
-        self.assertTrue(nodes[9].target == ttnn.mul)
-        self.assertTrue(nodes[9].args[0].target == ttnn.to_device)
-        self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout)
-        self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch)
-        self.assertTrue(nodes[9].args[1].target == ttnn.reciprocal)
-        self.assertTrue(nodes[10].target == ttnn.from_device)
-        self.assertTrue(nodes[11].target == ttnn.to_layout)
-        self.assertTrue(nodes[12].target == ttnn.to_torch)
+        self.assertTrue(nodes[1].target == ttnn.full)
+        self.assertTrue(nodes[3].target == ttnn.reciprocal)
+        self.assertTrue(nodes[3].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[7].target == ttnn.mul)
+        self.assertTrue(nodes[7].args[0].target == ttnn.to_device)
+        self.assertTrue(nodes[7].args[0].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[7].args[0].args[0].args[0].target == ttnn.from_torch)
+        self.assertTrue(nodes[7].args[1].target == ttnn.reciprocal)
+        self.assertTrue(nodes[8].target == ttnn.from_device)
+        self.assertTrue(nodes[9].target == ttnn.to_layout)
+        self.assertTrue(nodes[10].target == ttnn.to_torch)
         # Check inference result
         self.assertTrue(check_with_pcc(result_before, result_after))
 
@@ -972,14 +994,15 @@ def test_rsub_scalar(self):
 
         # Check the graph has been rewritten and contains ttnn ops
        nodes = list(option._out_fx_graphs[0].nodes)
-        # self.aseertTrue(nodes[1].target == ttnn.full)
-        self.assertTrue(nodes[8].target == ttnn.sub)
-        self.assertTrue(nodes[8].args[0].target == ttnn.to_device)
-        self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout)
-        self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch)
-        self.assertTrue(nodes[9].target == ttnn.from_device)
-        self.assertTrue(nodes[10].target == ttnn.to_layout)
-        self.assertTrue(nodes[11].target == ttnn.to_torch)
+        self.assertTrue(nodes[1].target == ttnn.full)
+        self.assertTrue(nodes[6].target == ttnn.sub)
+        self.assertTrue(nodes[6].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[6].args[1].target == ttnn.to_device)
+        self.assertTrue(nodes[6].args[1].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[6].args[1].args[0].args[0].target == ttnn.from_torch)
+        self.assertTrue(nodes[7].target == ttnn.from_device)
+        self.assertTrue(nodes[8].target == ttnn.to_layout)
+        self.assertTrue(nodes[9].target == ttnn.to_torch)
         # Check inference result
         self.assertTrue(check_with_pcc(result_before, result_after, 0.9998))
 
@@ -1816,6 +1839,51 @@ def test_sigmoid(self):
         # Check inference result
         self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))
 
+    def test_to_copy(self):
+        m = ToCopyModule()
+        input_shapes = m.input_shapes()
+        inputs = [torch.rand(shape) for shape in input_shapes]
+        result_before = m.forward(*inputs)
+        option = torch_ttnn.TorchTtnnOption(device=self.device)
+        option.gen_graphviz = True
+        # The compilation is lazy, so we need to run forward once to trigger the compilation
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+        result_after = m.forward(*inputs)
+        option._out_fx_graphs[0].print_tabular()
+
+        # Check the graph has been rewritten and contains ttnn ops
+        nodes = list(option._out_fx_graphs[0].nodes)
+        self.assertTrue(nodes[2].target == ttnn.as_tensor)
+        self.assertTrue(nodes[3].target == ttnn.from_device)
+        self.assertTrue(nodes[4].target == ttnn.to_layout)
+        self.assertTrue(nodes[5].target == ttnn.to_torch)
+        # Check inference result
+        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))
+
+    def test_to_copy_with_op_after(self):
+        m = ToCopyWithOpAfterModule()
+        input_shapes = m.input_shapes()
+        inputs = [torch.rand(shape) for shape in input_shapes]
+        result_before = m.forward(*inputs)
+        option = torch_ttnn.TorchTtnnOption(device=self.device)
+        option.gen_graphviz = True
+        # The compilation is lazy, so we need to run forward once to trigger the compilation
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+        result_after = m.forward(*inputs)
+        option._out_fx_graphs[0].print_tabular()
+
+        # Check the graph has been rewritten and contains ttnn ops
+        nodes = list(option._out_fx_graphs[0].nodes)
+        self.assertTrue(nodes[2].target == ttnn.as_tensor)
+        self.assertTrue(nodes[3].target == ttnn.add)
+        self.assertTrue(nodes[3].args[0].target == ttnn.as_tensor)
+        self.assertTrue(nodes[3].args[1].target == ttnn.as_tensor)
+        self.assertTrue(nodes[4].target == ttnn.from_device)
+        self.assertTrue(nodes[5].target == ttnn.to_layout)
+        self.assertTrue(nodes[6].target == ttnn.to_torch)
+        # Check inference result
+        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))
+
 
 if __name__ == "__main__":
     unittest.main()
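
The hard-coded node indices in the assertions above shift whenever the lowering changes; here ttnn.full is emitted directly (instead of lowering aten.full through from_torch/to_layout/to_device), which removes several nodes and renumbers the rest. A hypothetical helper for re-deriving the indices, equivalent to reading the print_tabular() output:

    def dump_graph(option):
        # Hypothetical convenience wrapper (not part of the repo): print each node
        # with its index so expected positions in the assertions can be re-derived.
        for i, node in enumerate(option._out_fx_graphs[0].nodes):
            print(i, node.op, node.target)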

tests/test_type_conversion.py

Lines changed: 0 additions & 89 deletions
This file was deleted.
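Its coverage of the bfloat16 cast appears to move into the new test_to_copy and test_to_copy_with_op_after cases in tests/test_more_ops.py above, now exercised through the ttnn.as_tensor lowering rather than the deleted conversion passes.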

tools/run_transformers.py

Lines changed: 7 additions & 3 deletions
@@ -44,13 +44,17 @@ def run_model(
     ]
 
     if model.model_task in text_modules:
-        tokenizer = AutoTokenizer.from_pretrained(model.model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model.model_name, padding_side="left", torch_dtype=torch.bfloat16
+        )
     elif model.model_task in vision_modules:
-        image_processor = AutoImageProcessor.from_pretrained(model.model_name)
+        image_processor = AutoImageProcessor.from_pretrained(
+            model.model_name, torch_dtype=torch.bfloat16
+        )
     else:
         raise ValueError(f"model task: {model.model_task} not supported.")
 
-    m = model.model_task.from_pretrained(model.model_name)
+    m = model.model_task.from_pretrained(model.model_name, torch_dtype=torch.bfloat16)
 
     if backward:
         try:

torch_ttnn/backend.py

Lines changed: 0 additions & 6 deletions
@@ -32,12 +32,6 @@ def aten_backend(
 
     gm = remove_clones_for_input_aliasing(gm)
 
-    # Change float types in dtype kwargs to bfloat16
-    from .convert_type import convert_dtype_to_bfloat16, convert_float_to_bfloat16
-
-    gm = convert_float_to_bfloat16(gm)
-    gm = convert_dtype_to_bfloat16(gm)
-
     option: TorchTtnnOption = options["torch_ttnn_option"]
     torch.fx.graph._register_custom_builtin("ttnn_Specified_Device", "", option.device)
     torch.fx.graph._register_custom_builtin(
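
With models now instantiated directly in bfloat16 via torch_dtype (see tests/test_bert.py and tools/run_transformers.py above), the backend no longer needs to rewrite float dtypes inside the graph, so both convert_type passes drop out of aten_backend.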
