Add a QuestionAnswering bert model for testing (#19)
* Pin transformers to the same version used by current tt-metal

* Remove the type conversion passes and instead have transformers initialize the model directly in bfloat16

* Convert torch.Tensor.to to ttnn.as_tensor (a sketch of this rewrite follows the change summary below)

* Use ttnn.full instead of aten.full where applicable

* Convert the bert model script to a unittest

* Move the transformers installation to the dev requirements

* Refactor the model input and output print statement in test_bert
kevinwuTT authored Jul 2, 2024
1 parent 4f8daaf commit 4da248b
Showing 9 changed files with 235 additions and 202 deletions.
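
The torch.Tensor.to → ttnn.as_tensor bullet refers to an FX graph rewrite. Below is a minimal sketch of the idea only, not the repository's actual pass: the name rewrite_to_copy_sketch and the keyword arguments passed to ttnn.as_tensor are assumptions (exact signatures vary across ttnn versions), and the real backend also inserts the from_device/to_layout/to_torch conversions that the new tests assert on.

import torch
from torch_ttnn import ttnn


def rewrite_to_copy_sketch(gm: torch.fx.GraphModule, device) -> torch.fx.GraphModule:
    # torch.Tensor.to(torch.bfloat16) traces to aten._to_copy(dtype=bfloat16);
    # replace each such node with a single ttnn.as_tensor call.
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten._to_copy.default
            and node.kwargs.get("dtype") == torch.bfloat16
        ):
            with gm.graph.inserting_before(node):
                # The kwargs below are assumptions about ttnn.as_tensor.
                replacement = gm.graph.call_function(
                    ttnn.as_tensor,
                    args=(node.args[0],),
                    kwargs={
                        "dtype": ttnn.bfloat16,
                        "layout": ttnn.TILE_LAYOUT,
                        "device": device,
                    },
                )
            node.replace_all_uses_with(replacement)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm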
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -2,3 +2,4 @@
 pytest==7.2.2
 pytest-timeout==2.2.0
 pre-commit==3.0.4
+transformers==4.38.0
86 changes: 86 additions & 0 deletions tests/test_bert.py
@@ -0,0 +1,86 @@
import torch
import torch_ttnn
import unittest
from torch_ttnn import ttnn
import collections

# Load model directly
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
)


class TestBert(unittest.TestCase):
    def setUp(self):
        # Open device 0
        self.device: ttnn.Device = ttnn.open_device(device_id=0)

    def tearDown(self):
        # Close the device
        ttnn.close_device(self.device)

    def test_bert(self):
        # Download model from cloud
        model_name = "phiyodr/bert-large-finetuned-squad2"
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, padding_side="left", torch_dtype=torch.bfloat16
        )
        m = AutoModelForQuestionAnswering.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        )
        m.eval()

        # Set up sample input
        context = 'Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. "The prophet and founding hero of modern archaeology", Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art. '
        question = "What discipline did Winkelmann create?"

        inputs = tokenizer.encode_plus(
            question,
            context,
            add_special_tokens=True,
            return_tensors="pt",
            max_length=256,
            padding="max_length",
            truncation=True,
        )

        # Run inference with the original model
        with torch.no_grad():
            outputs_before = m(**inputs)

        # Helper function to decode output to human-readable text
        def decode_output(outputs):
            response_start = torch.argmax(outputs.start_logits)
            response_end = torch.argmax(outputs.end_logits) + 1
            response_tokens = inputs.input_ids[0, response_start:response_end]
            return tokenizer.decode(response_tokens)

        answer_before = decode_output(outputs_before)

        # Compile model with ttnn backend
        option = torch_ttnn.TorchTtnnOption(device=self.device)
        m = torch.compile(m, backend=torch_ttnn.backend, options=option)

        # Run inference with the compiled model
        with torch.no_grad():
            outputs_after = m(**inputs)
        option._out_fx_graphs[0].print_tabular()

        answer_after = decode_output(outputs_after)

        print(
            f"""
            model_name: {model_name}
            input:
              context: {context}
              question: {question}
            answer before: {answer_before}
            answer after: {answer_after}
            """
        )

        # TODO: Add more checks for the compiled graph

        # Check inference result
        self.assertEqual(answer_before, answer_after)
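
A usage note: assuming the dev requirements above are installed (including transformers==4.38.0) and a Tenstorrent device is available as device 0, this test should be runnable with "pytest tests/test_bert.py" or "python -m unittest tests.test_bert".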
108 changes: 88 additions & 20 deletions tests/test_more_ops.py
@@ -555,6 +555,29 @@ def input_shapes(self):
         return [(4, 4)]
 
 
+class ToCopyModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.to(torch.bfloat16)
+
+    def input_shapes(self):
+        return [(4, 4)]
+
+
+class ToCopyWithOpAfterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        to = x.to(torch.bfloat16)
+        return torch.add(to, to)
+
+    def input_shapes(self):
+        return [(4, 4)]
+
+
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
@@ -894,18 +917,17 @@ def test_div_scalar_denom(self):
 
         # Check the graph has be rewritten and contain ttnn ops
         nodes = list(option._out_fx_graphs[0].nodes)
-        self.assertTrue(nodes[5].target == ttnn.reciprocal)
-        self.assertTrue(nodes[5].args[0].target == ttnn.to_device)
-        self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout)
-        self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch)
-        self.assertTrue(nodes[9].target == ttnn.mul)
-        self.assertTrue(nodes[9].args[0].target == ttnn.to_device)
-        self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout)
-        self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch)
-        self.assertTrue(nodes[9].args[1].target == ttnn.reciprocal)
-        self.assertTrue(nodes[10].target == ttnn.from_device)
-        self.assertTrue(nodes[11].target == ttnn.to_layout)
-        self.assertTrue(nodes[12].target == ttnn.to_torch)
+        self.assertTrue(nodes[1].target == ttnn.full)
+        self.assertTrue(nodes[3].target == ttnn.reciprocal)
+        self.assertTrue(nodes[3].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[7].target == ttnn.mul)
+        self.assertTrue(nodes[7].args[0].target == ttnn.to_device)
+        self.assertTrue(nodes[7].args[0].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[7].args[0].args[0].args[0].target == ttnn.from_torch)
+        self.assertTrue(nodes[7].args[1].target == ttnn.reciprocal)
+        self.assertTrue(nodes[8].target == ttnn.from_device)
+        self.assertTrue(nodes[9].target == ttnn.to_layout)
+        self.assertTrue(nodes[10].target == ttnn.to_torch)
         # Check inference result
         self.assertTrue(check_with_pcc(result_before, result_after))
 
@@ -972,14 +994,15 @@ def test_rsub_scalar(self):
 
         # Check the graph has be rewritten and contain ttnn ops
         nodes = list(option._out_fx_graphs[0].nodes)
-        # self.aseertTrue(nodes[1].target == ttnn.full)
-        self.assertTrue(nodes[8].target == ttnn.sub)
-        self.assertTrue(nodes[8].args[0].target == ttnn.to_device)
-        self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout)
-        self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch)
-        self.assertTrue(nodes[9].target == ttnn.from_device)
-        self.assertTrue(nodes[10].target == ttnn.to_layout)
-        self.assertTrue(nodes[11].target == ttnn.to_torch)
+        self.assertTrue(nodes[1].target == ttnn.full)
+        self.assertTrue(nodes[6].target == ttnn.sub)
+        self.assertTrue(nodes[6].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[6].args[1].target == ttnn.to_device)
+        self.assertTrue(nodes[6].args[1].args[0].target == ttnn.to_layout)
+        self.assertTrue(nodes[6].args[1].args[0].args[0].target == ttnn.from_torch)
+        self.assertTrue(nodes[7].target == ttnn.from_device)
+        self.assertTrue(nodes[8].target == ttnn.to_layout)
+        self.assertTrue(nodes[9].target == ttnn.to_torch)
         # Check inference result
         self.assertTrue(check_with_pcc(result_before, result_after, 0.9998))
 
@@ -1816,6 +1839,51 @@ def test_sigmoid(self):
         # Check inference result
         self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))
 
+    def test_to_copy(self):
+        m = ToCopyModule()
+        input_shapes = m.input_shapes()
+        inputs = [torch.rand(shape) for shape in input_shapes]
+        result_before = m.forward(*inputs)
+        option = torch_ttnn.TorchTtnnOption(device=self.device)
+        option.gen_graphviz = True
+        # The compilation is lazy, so we need to run forward once to trigger the compilation
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+        result_after = m.forward(*inputs)
+        option._out_fx_graphs[0].print_tabular()
+
+        # Check the graph has be rewritten and contain ttnn ops
+        nodes = list(option._out_fx_graphs[0].nodes)
+        self.assertTrue(nodes[2].target == ttnn.as_tensor)
+        self.assertTrue(nodes[3].target == ttnn.from_device)
+        self.assertTrue(nodes[4].target == ttnn.to_layout)
+        self.assertTrue(nodes[5].target == ttnn.to_torch)
+        # Check inference result
+        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))
+
+    def test_to_copy_with_op_after(self):
+        m = ToCopyWithOpAfterModule()
+        input_shapes = m.input_shapes()
+        inputs = [torch.rand(shape) for shape in input_shapes]
+        result_before = m.forward(*inputs)
+        option = torch_ttnn.TorchTtnnOption(device=self.device)
+        option.gen_graphviz = True
+        # The compilation is lazy, so we need to run forward once to trigger the compilation
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+        result_after = m.forward(*inputs)
+        option._out_fx_graphs[0].print_tabular()
+
+        # Check the graph has be rewritten and contain ttnn ops
+        nodes = list(option._out_fx_graphs[0].nodes)
+        self.assertTrue(nodes[2].target == ttnn.as_tensor)
+        self.assertTrue(nodes[3].target == ttnn.add)
+        self.assertTrue(nodes[3].args[0].target == ttnn.as_tensor)
+        self.assertTrue(nodes[3].args[1].target == ttnn.as_tensor)
+        self.assertTrue(nodes[4].target == ttnn.from_device)
+        self.assertTrue(nodes[5].target == ttnn.to_layout)
+        self.assertTrue(nodes[6].target == ttnn.to_torch)
+        # Check inference result
+        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))
+
 
 if __name__ == "__main__":
     unittest.main()
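
To make the updated div/rsub assertions above easier to follow: the backend now materializes the scalar operand on device with ttnn.full and lowers division to reciprocal-multiply. Below is a hedged eager-mode sketch of the dataflow that test_div_scalar_denom traces; the op names come from the assertions, but the exact signatures (e.g. ttnn.full's arguments) are assumptions that may differ across ttnn versions.

import torch
from torch_ttnn import ttnn


def div_by_scalar_sketch(device, x: torch.Tensor, denom: float) -> torch.Tensor:
    # from_torch -> to_layout -> to_device mirrors how inputs enter the graph.
    x_tt = ttnn.from_torch(x)
    x_tt = ttnn.to_layout(x_tt, ttnn.TILE_LAYOUT)
    x_tt = ttnn.to_device(x_tt, device)
    d_tt = ttnn.full(x.shape, denom)          # scalar denominator, cf. nodes[1]
    d_tt = ttnn.to_layout(d_tt, ttnn.TILE_LAYOUT)
    r_tt = ttnn.reciprocal(d_tt)              # 1 / denom, cf. nodes[3]
    y_tt = ttnn.mul(x_tt, r_tt)               # x * (1 / denom) == x / denom, cf. nodes[7]
    y_tt = ttnn.from_device(y_tt)
    y_tt = ttnn.to_layout(y_tt, ttnn.ROW_MAJOR_LAYOUT)
    return ttnn.to_torch(y_tt)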
89 changes: 0 additions & 89 deletions tests/test_type_conversion.py

This file was deleted. (The type-conversion passes it exercised were removed from the backend; see the new to_copy tests in tests/test_more_ops.py.)

10 changes: 7 additions & 3 deletions tools/run_transformers.py
@@ -44,13 +44,17 @@ def run_model(
     ]
 
     if model.model_task in text_modules:
-        tokenizer = AutoTokenizer.from_pretrained(model.model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model.model_name, padding_side="left", torch_dtype=torch.bfloat16
+        )
     elif model.model_task in vision_modules:
-        image_processor = AutoImageProcessor.from_pretrained(model.model_name)
+        image_processor = AutoImageProcessor.from_pretrained(
+            model.model_name, torch_dtype=torch.bfloat16
+        )
     else:
         raise ValueError(f"model task: {model.model_task} not supported.")
 
-    m = model.model_task.from_pretrained(model.model_name)
+    m = model.model_task.from_pretrained(model.model_name, torch_dtype=torch.bfloat16)
 
     if backward:
         try:
6 changes: 0 additions & 6 deletions torch_ttnn/backend.py
@@ -32,12 +32,6 @@ def aten_backend(
 
     gm = remove_clones_for_input_aliasing(gm)
 
-    # Change float types in dtype kwargs to bfloat16
-    from .convert_type import convert_dtype_to_bfloat16, convert_float_to_bfloat16
-
-    gm = convert_float_to_bfloat16(gm)
-    gm = convert_dtype_to_bfloat16(gm)
-
     option: TorchTtnnOption = options["torch_ttnn_option"]
     torch.fx.graph._register_custom_builtin("ttnn_Specified_Device", "", option.device)
     torch.fx.graph._register_custom_builtin(
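
For context on the deletion above: loading models directly in bfloat16 makes a post-trace dtype-rewriting pass unnecessary. A rough reconstruction of what such a pass looks like follows; this is an illustration, not the deleted convert_type.py.

import torch


def convert_dtype_kwargs_to_bfloat16(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Rewrite any float32 `dtype` kwarg left in the traced graph to bfloat16.
    for node in gm.graph.nodes:
        if node.kwargs.get("dtype") == torch.float32:
            new_kwargs = dict(node.kwargs)
            new_kwargs["dtype"] = torch.bfloat16
            node.kwargs = new_kwargs
    gm.recompile()
    return gm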