
Commit 237ca12

Make changes to work on BERT
1 parent d183cc8 commit 237ca12

File tree

7 files changed: 301 additions & 214 deletions


tests/cpp_extension/test_cpp_extension.py

Lines changed: 71 additions & 43 deletions
@@ -7,47 +7,64 @@
 import torch_ttnn
 from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module, enable_ttnn_device
 import pytest
+import time
 
 from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 
 import logging
 import sys
 
+
 @pytest.mark.parametrize(
     "input_shape",
-    ((32, 1, 3, 3), (32,)),
+    ((32, 1, 3, 3), (1, 32)),
+)
+@pytest.mark.parametrize(
+    "dtype",
+    (torch.bfloat16, torch.int32),
 )
-def test_cpp_extension(device, input_shape):
-    torch.utils.rename_privateuse1_backend('ttnn')
+def test_cpp_extension(device, input_shape, dtype):
+    torch.utils.rename_privateuse1_backend("ttnn")
 
     # in pytest the device has already been initialized before this call
     # so instead we can wrap this around the custom device
     ttnn_device = ttnn_module.custom_device_from_ttnn(device)
 
     logging.info("Creating bfloat tensor from -1 to 1")
-    torch_tensor = torch.empty(input_shape, dtype = torch.bfloat16).uniform_(-1, 1)
+    if dtype == torch.bfloat16:
+        torch_tensor = torch.empty(input_shape, dtype=dtype).uniform_(-1, 1)
+    elif dtype == torch.int32:
+        torch_tensor = torch.randint(-1000, 1000, input_shape)
+        torch_tensor = torch_tensor.to(torch.int32)
+    else:
+        raise Exception(f"{dtype} not being tested at this time")
     print(torch_tensor)
-    torch_tensor_abs = torch.abs(torch_tensor)
-    print(torch_tensor_abs)
 
     logging.info("Transferring to ttnn")
     torch_ttnn_tensor = torch_tensor.to(ttnn_device)
 
-    logging.info("get underlying ttnn tensor")
+    logging.info("Get underlying ttnn tensor")
     ttnn_tensor = ttnn_module.get_ttnn_tensor(torch_ttnn_tensor)
 
-    logging.info("Running abs on ttnn")
-    ttnn_tensor = ttnn.abs(ttnn_tensor)
+    # Compare output of abs op for bfloat16 dtype since ttnn.abs does not support int
+    if dtype == torch.bfloat16:
+        torch_out = torch.abs(torch_tensor)
+        print(torch_out)
+
+        logging.info("Running abs on ttnn")
+        ttnn_tensor = ttnn.abs(ttnn_tensor)
+    elif dtype == torch.int32:
+        torch_out = torch_tensor
+    else:
+        raise Exception(f"{dtype} not being tested at this time")
 
     logging.info("calling to_torch")
     ttnn_to_torch = ttnn.to_torch(ttnn_tensor)
+
     print(ttnn_to_torch)
-
-
-    assert torch.allclose(torch_tensor_abs, ttnn_to_torch, rtol=0.1, atol=0.1)
 
-    # logging.info("Closing device")
-    # ttnn_module.close_custom_device(ttnn_device)
+    assert torch.allclose(torch_out, ttnn_to_torch, rtol=0.1, atol=0.1)
+
 
 def test_bert_with_cpp_extension(device):
     model_name = "phiyodr/bert-large-finetuned-squad2"
@@ -66,34 +83,45 @@ def test_bert_with_cpp_extension(device):
     )
 
     option = torch_ttnn.TorchTtnnOption(
-        device=device,
-        gen_graphviz=False,
-        run_mem_analysis=False,
-        metrics_path=model_name,
-        verbose=True,
-    )
+        device=device,
+        gen_graphviz=False,
+        run_mem_analysis=False,
+        metrics_path=model_name,
+        verbose=True,
+    )
 
     # custom device
-    torch.utils.rename_privateuse1_backend('ttnn')
+    torch.utils.rename_privateuse1_backend("ttnn")
     ttnn_device = ttnn_module.custom_device_from_ttnn(device)
-
+
     # clone input_ids on cpu since this the data transfer is somehow inplace?
    input_ids = inputs.input_ids.clone()
-
-    inputs = inputs.to(ttnn_device)
-    # modules are inplace, tensors are not
-    m.to(ttnn_device)
 
-    model = torch.compile(m, backend=torch_ttnn.backend, options=option)
-    outputs = model(**inputs)
-
     # Helper function to decode output to human-readable text
     def decode_output(outputs):
         response_start = torch.argmax(outputs.start_logits)
         response_end = torch.argmax(outputs.end_logits) + 1
         response_tokens = input_ids[0, response_start:response_end]
         return tokenizer.decode(response_tokens)
 
+    # comment out these to disable cpp extension
+    start_to = time.perf_counter() * 1000
+    inputs = inputs.to(ttnn_device)
+    # modules are inplace, tensors are not
+    m.to(ttnn_device)
+    end_to = time.perf_counter() * 1000
+    print(f"to: {end_to - start_to} (ms)")
+
+    model = torch.compile(m, backend=torch_ttnn.backend, options=option)
+
+    for idx in range(5):
+        start = time.perf_counter() * 1000
+        # Don't need to reset options if inputs don't change because of cache
+        outputs = model(**inputs)
+        end = time.perf_counter() * 1000
+        run_time = end - start
+        print(f"iter {idx}: {run_time} (ms)")
+
     print("finished:")
     print(outputs)
     answer = decode_output(outputs)
@@ -108,9 +136,10 @@ def decode_output(outputs):
     """
     )
 
+
 # adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py
 class MnistModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self):
         super(MnistModel, self).__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
@@ -133,8 +162,9 @@ def forward(self, x):
         x = self.fc2(x)
         x = F.log_softmax(x, dim=1)
         return x
-
-def test_mnist_with_cpp_extension(device):
+
+@pytest.mark.skip(reason="Does not support conv for now")
+def test_mnist_with_cpp_extension(device):
     model_name = "Mnist"
     transform = transforms.Compose([transforms.ToTensor()])
     test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
@@ -143,23 +173,21 @@ def test_mnist_with_cpp_extension(device):
     test_input = test_input.to(torch.bfloat16)
 
     # Copy weights and biases to ttnn
-    torch.utils.rename_privateuse1_backend('ttnn')
+    torch.utils.rename_privateuse1_backend("ttnn")
     ttnn_device = ttnn_module.custom_device_from_ttnn(device)
-
 
-
     option = torch_ttnn.TorchTtnnOption(
-        device=device,
-        gen_graphviz=False,
-        run_mem_analysis=False,
-        metrics_path=model_name,
-        verbose=True,
-    )
+        device=device,
+        gen_graphviz=False,
+        run_mem_analysis=False,
+        metrics_path=model_name,
+        verbose=True,
+    )
 
     model = MnistModel()
     model = model.to(torch.bfloat16)
     test_input = test_input.to(ttnn_device)
     model.to(ttnn_device)
-
+
     model = torch.compile(model, backend=torch_ttnn.backend, options=option)
-    results = model(test_input)
+    results = model(test_input)
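
For readers skimming the diff: the round-trip this test now exercises, condensed into a standalone sketch. The `device` handle comes from the suite's pytest fixture, and `roundtrip_abs` is an illustrative name, not code from this commit.

import torch
import ttnn
from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module

def roundtrip_abs(device, input_shape=(1, 32)):
    # Register the PrivateUse1 backend under the name "ttnn" and wrap
    # the already-open ttnn device (pytest opens it via a fixture).
    torch.utils.rename_privateuse1_backend("ttnn")
    ttnn_device = ttnn_module.custom_device_from_ttnn(device)

    cpu = torch.empty(input_shape, dtype=torch.bfloat16).uniform_(-1, 1)
    on_device = cpu.to(ttnn_device)              # torch -> ttnn transfer
    tt = ttnn_module.get_ttnn_tensor(on_device)  # unwrap the underlying ttnn.Tensor
    tt = ttnn.abs(tt)                            # int32 inputs skip this op in the test
    back = ttnn.to_torch(tt)                     # ttnn -> torch transfer
    assert torch.allclose(torch.abs(cpu), back, rtol=0.1, atol=0.1)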

tests/utils.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@
 import collections
 import re
 from typing import List, Dict, Tuple
+from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module, enable_ttnn_device
 
 
 class ModelTester:
@@ -130,6 +131,9 @@ def test_model_eval(self, as_ttnn=False, option=None):
         model = self.set_model_eval(self.model)
         inputs = self.set_inputs_eval(self.inputs)
         if as_ttnn == True:
+            torch.utils.rename_privateuse1_backend("ttnn")
+            ttnn_device = ttnn_module.custom_device_from_ttnn(option.device)
+            inputs = inputs.to(ttnn_device)
             model = self.compile_model(model, option)
         outputs = self.run_model(model, inputs)
         results = self.get_results_eval(model, inputs, outputs)
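
The three added lines are the only ttnn-specific work in the eval path: rename the PrivateUse1 backend, wrap the option's already-open device, and move the inputs onto it. A hypothetical call site might look like the sketch below; `SomeModelTester` and its construction are assumptions, not code from this commit.

option = torch_ttnn.TorchTtnnOption(device=device, metrics_path="demo")
tester = SomeModelTester()  # hypothetical ModelTester subclass
results = tester.test_model_eval(as_ttnn=True, option=option)
# With as_ttnn=True, inputs land on the wrapped ttnn device before
# compile_model()/run_model(), exercising the cpp-extension transfer path.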

torch_ttnn/cpp_extension/TtnnOpaqueTensorImpl.h renamed to torch_ttnn/cpp_extension/TtnnTensorImpl.hpp

Lines changed: 4 additions & 47 deletions
@@ -1,23 +1,14 @@
 #pragma once
 
-#include <ATen/OpaqueTensorImpl.h>
 #include "ttnn/tensor/tensor.hpp"
+#include "extension_utils.hpp"
 #include <iostream>
 #include <string.h>
 
-template <typename Arg, typename... Args>
-void doPrint(std::ostream& out, const std::string_view& filename, int lineno, const std::string_view& fn, Arg&& arg, Args&&... args)
-{
-  out << std::format("{}({})({}): ", filename, lineno, fn);
-  out << std::forward<Arg>(arg);
-  ((out << std::forward<Args>(args)), ...);
-  out << std::endl;
-}
-#define LOGGING(...) doPrint(std::cout, __FILE_NAME__, __LINE__, __FUNCTION__, __VA_ARGS__)
-
 namespace at {
 
 struct TtnnTensorImpl : public TensorImpl {
+  // TODO: Only difference is the storage type, combine these two
   TtnnTensorImpl(
       at::DispatchKeySet key_set,
       const caffe2::TypeMeta data_type,
@@ -55,8 +46,7 @@ struct TtnnTensorImpl : public TensorImpl {
   }
 
   ttnn::Tensor get_ttnn_tensor() {
-    // LOGGING(ttnn_tensor_string_);
-    LOGGING(ttnn_tensor_.write_to_string());
+    LOGGING("");
     return ttnn_tensor_;
   }
 
@@ -129,42 +119,9 @@ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
     refresh_numel();
   }
-
-  // protected:
-  //   static void copy_tensor_metadata(
-  //       const TtnnTensorImpl* src_impl,
-  //       TtnnTensorImpl* dest_impl,
-  //       const c10::VariableVersion& version_counter,
-  //       bool allow_tensor_metadata_change) {
-  //     TensorImpl::copy_tensor_metadata(
-  //         src_impl,
-  //         dest_impl,
-  //         version_counter,
-  //         allow_tensor_metadata_change);
-
-  //     // TtnnTensorImpl-specific fields.
-  //     dest_impl->ttnn_tensor_ = src_impl->ttnn_tensor_;
-  //     dest_impl->ttnn_tensor_string_ = src_impl->ttnn_tensor_string_;
-  //   }
-
-  //   static void copy_tensor_metadata(
-  //       const TtnnTensorImpl* src_impl,
-  //       TtnnTensorImpl* dest_impl,
-  //       c10::VariableVersion&& version_counter,
-  //       bool allow_tensor_metadata_change) {
-  //     TensorImpl::copy_tensor_metadata(
-  //         src_impl,
-  //         dest_impl,
-  //         std::move(version_counter),
-  //         allow_tensor_metadata_change);
-
-  //     // TtnnTensorImpl-specific fields.
-  //     dest_impl->ttnn_tensor_ = src_impl->ttnn_tensor_;
-  //     dest_impl->ttnn_tensor_string_ = src_impl->ttnn_tensor_string_;
-  //   }
-
 private:
   ttnn::Tensor ttnn_tensor_;
+  // TODO: Debug only, should probably remove as it might be costly
   std::string ttnn_tensor_string_;
 };
torch_ttnn/cpp_extension/custom_device_mode.py

Lines changed: 11 additions & 9 deletions
@@ -6,7 +6,7 @@
 import glob
 import logging
 
-assert os.environ.get('TT_METAL_HOME') is not None
+assert os.environ.get("TT_METAL_HOME") is not None
 tt_metal_home = Path(os.environ["TT_METAL_HOME"])
 
 cpmcache_pattern = Path(".cpmcache/**/include")
@@ -34,7 +34,7 @@
     tt_metal_home / Path("tt_metal/api"),
     tt_metal_home / Path("tt_metal/tt_stl"),
     tt_metal_home / Path("tt_metal/include/tt_metal/internal"),
-] + cpmcache_dirs
+] + cpmcache_dirs
 ttnn_include_paths = [str(p) for p in ttnn_include_paths]
 
 # Load the C++ extension containing your custom kernels.
@@ -64,12 +64,13 @@
         str(working_directory / "open_registration_extension.cpp"),
     ],
     extra_include_paths=[str(working_directory)] + ttnn_include_paths,
-    extra_cflags=["-g", "-DFMT_HEADER_ONLY", '-std=c++20', '-stdlib=libc++'],
+    extra_cflags=["-g", "-DFMT_HEADER_ONLY", "-std=c++20", "-stdlib=libc++"],
     extra_ldflags=tt_metal_lib_paths + tt_metal_libs,
     verbose=True,
 )
 
-logging.info('Loaded custom extension.')
+logging.info("Loaded custom extension.")
+
 
 # The user will globally enable the below mode when calling this API
 def enable_ttnn_device():
@@ -78,6 +79,7 @@ def enable_ttnn_device():
     # If you want the mode to never be disabled, then this function shouldn't return anything.
     return m
 
+
 # This is a simple TorchFunctionMode class that:
 # (a) Intercepts all torch.* calls
 # (b) Checks for kwargs of the form `device="foo:i"`
@@ -87,14 +89,14 @@ class TtnnDeviceMode(TorchFunctionMode):
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
-        if 'device' in kwargs and 'ttnn' in kwargs['device']:
-            device_and_idx = kwargs['device'].split(':')
+        if "device" in kwargs and "ttnn" in kwargs["device"]:
+            device_and_idx = kwargs["device"].split(":")
             if len(device_and_idx) == 1:
                 # Case 1: No index specified
-                kwargs['device'] = ttnn_module.custom_device()
+                kwargs["device"] = ttnn_module.custom_device()
             else:
                 # Case 2: The user specified a device index.
                 device_idx = int(device_and_idx[1])
-                kwargs['device'] = ttnn_module.custom_device(device_idx)
+                kwargs["device"] = ttnn_module.custom_device(device_idx)
         with torch._C.DisableTorchFunction():
-            return func(*args, **kwargs)
+            return func(*args, **kwargs)
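
Taken together, enable_ttnn_device() and TtnnDeviceMode let plain torch.* factory calls target the extension by device string. A minimal usage sketch, assuming the extension compiles and a TT device is available:

import torch
from torch_ttnn.cpp_extension.custom_device_mode import enable_ttnn_device

mode = enable_ttnn_device()  # installs the TorchFunctionMode shown above

# "ttnn" with no index is rewritten to ttnn_module.custom_device();
# "ttnn:0" has its index parsed and becomes ttnn_module.custom_device(0).
x = torch.ones((2, 2), dtype=torch.bfloat16, device="ttnn")
y = torch.zeros((2, 2), dtype=torch.bfloat16, device="ttnn:0")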
