Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a QuestionAnswering bert model for testing #19

Merged
merged 7 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
pytest==7.2.2
pytest-timeout==2.2.0
pre-commit==3.0.4
transformers==4.38.0
86 changes: 86 additions & 0 deletions tests/test_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import torch
import torch_ttnn
import unittest
from torch_ttnn import ttnn
import collections

# Load model directly
from transformers import (
AutoTokenizer,
AutoModelForQuestionAnswering,
)


class TestBert(unittest.TestCase):
def setUp(self):
# Open device 0
self.device: ttnn.Device = ttnn.open_device(device_id=0)

def tearDown(self):
# Close the device
ttnn.close_device(self.device)

def test_bert(self):
# Download model from cloud
model_name = "phiyodr/bert-large-finetuned-squad2"
tokenizer = AutoTokenizer.from_pretrained(
model_name, padding_side="left", torch_dtype=torch.bfloat16
)
m = AutoModelForQuestionAnswering.from_pretrained(
model_name, torch_dtype=torch.bfloat16
)
m.eval()

# Set up sample input
context = 'Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. "The prophet and founding hero of modern archaeology", Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art. '
question = "What discipline did Winkelmann create?"

inputs = tokenizer.encode_plus(
question,
context,
add_special_tokens=True,
return_tensors="pt",
max_length=256,
padding="max_length",
truncation=True,
)

# Run inference with the original model
with torch.no_grad():
outputs_before = m(**inputs)

# Helper function to decode output to human-readable text
def decode_output(outputs):
response_start = torch.argmax(outputs.start_logits)
response_end = torch.argmax(outputs.end_logits) + 1
response_tokens = inputs.input_ids[0, response_start:response_end]
return tokenizer.decode(response_tokens)

answer_before = decode_output(outputs_before)

# Compile model with ttnn backend
option = torch_ttnn.TorchTtnnOption(device=self.device)
m = torch.compile(m, backend=torch_ttnn.backend, options=option)

# Run inference with the compiled model
with torch.no_grad():
outputs_after = m(**inputs)
option._out_fx_graphs[0].print_tabular()

answer_after = decode_output(outputs_after)

print(
f"""
model_name: {model_name}
input:
context: {context}
question: {question}
answer before: {answer_before}
answer after: {answer_after}
"""
)

# TODO: Add more checks for the compiled graph

# Check inference result
self.assertEqual(answer_before, answer_after)
108 changes: 88 additions & 20 deletions tests/test_more_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,29 @@ def input_shapes(self):
return [(4, 4)]


class ToCopyModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.to(torch.bfloat16)

def input_shapes(self):
return [(4, 4)]


class ToCopyWithOpAfterModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
to = x.to(torch.bfloat16)
return torch.add(to, to)

def input_shapes(self):
return [(4, 4)]


class TestModules(unittest.TestCase):
def setUp(self):
# Open device 0
Expand Down Expand Up @@ -894,18 +917,17 @@ def test_div_scalar_denom(self):

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
self.assertTrue(nodes[5].target == ttnn.reciprocal)
self.assertTrue(nodes[5].args[0].target == ttnn.to_device)
self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout)
self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch)
self.assertTrue(nodes[9].target == ttnn.mul)
self.assertTrue(nodes[9].args[0].target == ttnn.to_device)
self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout)
self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch)
self.assertTrue(nodes[9].args[1].target == ttnn.reciprocal)
self.assertTrue(nodes[10].target == ttnn.from_device)
self.assertTrue(nodes[11].target == ttnn.to_layout)
self.assertTrue(nodes[12].target == ttnn.to_torch)
self.assertTrue(nodes[1].target == ttnn.full)
self.assertTrue(nodes[3].target == ttnn.reciprocal)
self.assertTrue(nodes[3].args[0].target == ttnn.to_layout)
self.assertTrue(nodes[7].target == ttnn.mul)
self.assertTrue(nodes[7].args[0].target == ttnn.to_device)
self.assertTrue(nodes[7].args[0].args[0].target == ttnn.to_layout)
self.assertTrue(nodes[7].args[0].args[0].args[0].target == ttnn.from_torch)
self.assertTrue(nodes[7].args[1].target == ttnn.reciprocal)
self.assertTrue(nodes[8].target == ttnn.from_device)
self.assertTrue(nodes[9].target == ttnn.to_layout)
self.assertTrue(nodes[10].target == ttnn.to_torch)
# Check inference result
self.assertTrue(check_with_pcc(result_before, result_after))

Expand Down Expand Up @@ -972,14 +994,15 @@ def test_rsub_scalar(self):

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
# self.aseertTrue(nodes[1].target == ttnn.full)
self.assertTrue(nodes[8].target == ttnn.sub)
self.assertTrue(nodes[8].args[0].target == ttnn.to_device)
self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout)
self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch)
self.assertTrue(nodes[9].target == ttnn.from_device)
self.assertTrue(nodes[10].target == ttnn.to_layout)
self.assertTrue(nodes[11].target == ttnn.to_torch)
self.assertTrue(nodes[1].target == ttnn.full)
self.assertTrue(nodes[6].target == ttnn.sub)
self.assertTrue(nodes[6].args[0].target == ttnn.to_layout)
self.assertTrue(nodes[6].args[1].target == ttnn.to_device)
self.assertTrue(nodes[6].args[1].args[0].target == ttnn.to_layout)
self.assertTrue(nodes[6].args[1].args[0].args[0].target == ttnn.from_torch)
self.assertTrue(nodes[7].target == ttnn.from_device)
self.assertTrue(nodes[8].target == ttnn.to_layout)
self.assertTrue(nodes[9].target == ttnn.to_torch)
# Check inference result
self.assertTrue(check_with_pcc(result_before, result_after, 0.9998))

Expand Down Expand Up @@ -1816,6 +1839,51 @@ def test_sigmoid(self):
# Check inference result
self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))

def test_to_copy(self):
m = ToCopyModule()
input_shapes = m.input_shapes()
inputs = [torch.rand(shape) for shape in input_shapes]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=self.device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
self.assertTrue(nodes[2].target == ttnn.as_tensor)
self.assertTrue(nodes[3].target == ttnn.from_device)
self.assertTrue(nodes[4].target == ttnn.to_layout)
self.assertTrue(nodes[5].target == ttnn.to_torch)
# Check inference result
self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))

def test_to_copy_with_op_after(self):
m = ToCopyWithOpAfterModule()
input_shapes = m.input_shapes()
inputs = [torch.rand(shape) for shape in input_shapes]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=self.device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
self.assertTrue(nodes[2].target == ttnn.as_tensor)
self.assertTrue(nodes[3].target == ttnn.add)
self.assertTrue(nodes[3].args[0].target == ttnn.as_tensor)
self.assertTrue(nodes[3].args[1].target == ttnn.as_tensor)
self.assertTrue(nodes[4].target == ttnn.from_device)
self.assertTrue(nodes[5].target == ttnn.to_layout)
self.assertTrue(nodes[6].target == ttnn.to_torch)
# Check inference result
self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))


if __name__ == "__main__":
unittest.main()
89 changes: 0 additions & 89 deletions tests/test_type_conversion.py

This file was deleted.

10 changes: 7 additions & 3 deletions tools/run_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,17 @@ def run_model(
]

if model.model_task in text_modules:
tokenizer = AutoTokenizer.from_pretrained(model.model_name, padding_side="left")
tokenizer = AutoTokenizer.from_pretrained(
model.model_name, padding_side="left", torch_dtype=torch.bfloat16
)
elif model.model_task in vision_modules:
image_processor = AutoImageProcessor.from_pretrained(model.model_name)
image_processor = AutoImageProcessor.from_pretrained(
model.model_name, torch_dtype=torch.bfloat16
)
else:
raise ValueError(f"model task: {model.model_task} not supported.")

m = model.model_task.from_pretrained(model.model_name)
m = model.model_task.from_pretrained(model.model_name, torch_dtype=torch.bfloat16)

if backward:
try:
Expand Down
6 changes: 0 additions & 6 deletions torch_ttnn/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ def aten_backend(

gm = remove_clones_for_input_aliasing(gm)

# Change float types in dtype kwargs to bfloat16
from .convert_type import convert_dtype_to_bfloat16, convert_float_to_bfloat16

gm = convert_float_to_bfloat16(gm)
gm = convert_dtype_to_bfloat16(gm)

option: TorchTtnnOption = options["torch_ttnn_option"]
torch.fx.graph._register_custom_builtin("ttnn_Specified_Device", "", option.device)
torch.fx.graph._register_custom_builtin(
Expand Down
Loading