
Commit d183cc8

Integrate custom TTNN extension for PyTorch
1 parent f4b0287 commit d183cc8

File tree

5 files changed: +1017 −1 lines changed

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import ttnn
import torch_ttnn
from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module, enable_ttnn_device
import pytest

from transformers import AutoTokenizer, AutoModelForQuestionAnswering

import logging
import sys


@pytest.mark.parametrize(
    "input_shape",
    ((32, 1, 3, 3), (32,)),
)
def test_cpp_extension(device, input_shape):
    torch.utils.rename_privateuse1_backend('ttnn')

    # In pytest the device has already been initialized before this call,
    # so wrap the custom device around the existing one instead.
    ttnn_device = ttnn_module.custom_device_from_ttnn(device)

    logging.info("Creating bfloat16 tensor with values in [-1, 1]")
    torch_tensor = torch.empty(input_shape, dtype=torch.bfloat16).uniform_(-1, 1)
    print(torch_tensor)
    torch_tensor_abs = torch.abs(torch_tensor)
    print(torch_tensor_abs)

    logging.info("Transferring to ttnn")
    torch_ttnn_tensor = torch_tensor.to(ttnn_device)

    logging.info("Getting underlying ttnn tensor")
    ttnn_tensor = ttnn_module.get_ttnn_tensor(torch_ttnn_tensor)

    logging.info("Running abs on ttnn")
    ttnn_tensor = ttnn.abs(ttnn_tensor)

    logging.info("Calling to_torch")
    ttnn_to_torch = ttnn.to_torch(ttnn_tensor)
    print(ttnn_to_torch)

    assert torch.allclose(torch_tensor_abs, ttnn_to_torch, rtol=0.1, atol=0.1)

    # logging.info("Closing device")
    # ttnn_module.close_custom_device(ttnn_device)

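# The test above exercises the manual path: move a tensor to the custom
# device, unwrap it with get_ttnn_tensor, and run TTNN ops directly. The
# tests below instead route whole modules through torch.compile with
# torch_ttnn.backend.
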
def test_bert_with_cpp_extension(device):
    model_name = "phiyodr/bert-large-finetuned-squad2"
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", torch_dtype=torch.bfloat16)
    m = AutoModelForQuestionAnswering.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    context = 'Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. "The prophet and founding hero of modern archaeology", Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art. '
    question = "What discipline did Winkelmann create?"
    inputs = tokenizer.encode_plus(
        question,
        context,
        add_special_tokens=True,
        return_tensors="pt",
        max_length=256,
        padding="max_length",
        truncation=True,
    )

    option = torch_ttnn.TorchTtnnOption(
        device=device,
        gen_graphviz=False,
        run_mem_analysis=False,
        metrics_path=model_name,
        verbose=True,
    )

    # custom device
    torch.utils.rename_privateuse1_backend('ttnn')
    ttnn_device = ttnn_module.custom_device_from_ttnn(device)

    # Clone input_ids on the CPU first, since the device transfer appears
    # to be in-place.
    input_ids = inputs.input_ids.clone()

    inputs = inputs.to(ttnn_device)
    # Module.to() is in-place; Tensor.to() is not.
    m.to(ttnn_device)

    model = torch.compile(m, backend=torch_ttnn.backend, options=option)
    outputs = model(**inputs)

    # Helper function to decode output to human-readable text
    def decode_output(outputs):
        response_start = torch.argmax(outputs.start_logits)
        response_end = torch.argmax(outputs.end_logits) + 1
        response_tokens = input_ids[0, response_start:response_end]
        return tokenizer.decode(response_tokens)

    print("finished:")
    print(outputs)
    answer = decode_output(outputs)

    print(
        f"""
    model_name: {model_name}
    input:
        context: {context}
        question: {question}
    answer: {answer}
    """
    )

# adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py
class MnistModel(torch.nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x


def test_mnist_with_cpp_extension(device):
    model_name = "Mnist"
    transform = transforms.Compose([transforms.ToTensor()])
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    dataloader = DataLoader(test_dataset, batch_size=1)
    test_input, _ = next(iter(dataloader))
    test_input = test_input.to(torch.bfloat16)

    # Copy weights and biases to ttnn
    torch.utils.rename_privateuse1_backend('ttnn')
    ttnn_device = ttnn_module.custom_device_from_ttnn(device)

    option = torch_ttnn.TorchTtnnOption(
        device=device,
        gen_graphviz=False,
        run_mem_analysis=False,
        metrics_path=model_name,
        verbose=True,
    )

    model = MnistModel()
    model = model.to(torch.bfloat16)
    test_input = test_input.to(ttnn_device)
    model.to(ttnn_device)

    model = torch.compile(model, backend=torch_ttnn.backend, options=option)
    results = model(test_input)
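
For readers trying the extension outside pytest, where no `device` fixture is available, here is a minimal standalone sketch of the same round trip (not part of this commit; it assumes `ttnn.open_device`/`ttnn.close_device` as the device entry points, with the rest mirroring `test_cpp_extension` above):

import torch
import ttnn
from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module

torch.utils.rename_privateuse1_backend('ttnn')
device = ttnn.open_device(device_id=0)  # open a TTNN device directly
ttnn_device = ttnn_module.custom_device_from_ttnn(device)

x = torch.empty((32,), dtype=torch.bfloat16).uniform_(-1, 1)
x_ttnn = ttnn_module.get_ttnn_tensor(x.to(ttnn_device))  # unwrap to a ttnn.Tensor
y = ttnn.to_torch(ttnn.abs(x_ttnn))  # run abs in TTNN and copy back to torch

assert torch.allclose(torch.abs(x), y, rtol=0.1, atol=0.1)
ttnn.close_device(device)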
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
#pragma once

#include <ATen/OpaqueTensorImpl.h>
#include "ttnn/tensor/tensor.hpp"

#include <format>       // std::format, used by doPrint below
#include <iostream>
#include <string>       // std::string member of TtnnTensorImpl
#include <string_view>
#include <vector>

template <typename Arg, typename... Args>
void doPrint(std::ostream& out, const std::string_view& filename, int lineno,
             const std::string_view& fn, Arg&& arg, Args&&... args) {
  out << std::format("{}({})({}): ", filename, lineno, fn);
  out << std::forward<Arg>(arg);
  ((out << std::forward<Args>(args)), ...);
  out << std::endl;
}
#define LOGGING(...) doPrint(std::cout, __FILE_NAME__, __LINE__, __FUNCTION__, __VA_ARGS__)

namespace at {

struct TtnnTensorImpl : public TensorImpl {
  TtnnTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      c10::Device device,
      ttnn::Tensor& ttnn_tensor,
      c10::intrusive_ptr<c10::StorageImpl> storage)
      : TensorImpl(key_set, data_type, device),
        ttnn_tensor_(ttnn_tensor),
        ttnn_tensor_string_(ttnn_tensor.write_to_string()) {
    storage_ = std::move(storage);
    // Mirror the TTNN logical shape into the ATen sizes.
    auto view = ttnn_tensor_.get_logical_shape().view();
    std::vector<int64_t> view_int64;
    std::copy(view.begin(), view.end(), std::back_inserter(view_int64));
    sizes_and_strides_.set_sizes(IntArrayRef(view_int64.data(), view_int64.size()));
  }

  TtnnTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      c10::Device device,
      const ttnn::Tensor& ttnn_tensor,
      const Storage& storage)
      : TensorImpl(key_set, data_type, device),
        ttnn_tensor_(ttnn_tensor),
        ttnn_tensor_string_(ttnn_tensor.write_to_string()) {
    storage_ = storage;
    // Mirror the TTNN logical shape into the ATen sizes.
    auto view = ttnn_tensor_.get_logical_shape().view();
    std::vector<int64_t> view_int64;
    std::copy(view.begin(), view.end(), std::back_inserter(view_int64));
    sizes_and_strides_.set_sizes(IntArrayRef(view_int64.data(), view_int64.size()));
  }

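  // Note: the ATen-side sizes set in the constructors are a snapshot;
  // set_ttnn_tensor() below does not refresh them, so callers swapping in a
  // tensor of a different shape are expected to also call
  // set_sizes_and_strides() (or set_sizes_and_strides_as()).
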
  void set_sizes_and_strides(const IntArrayRef& int_array_ref) {
    sizes_and_strides_.set_sizes(int_array_ref);
  }

  void set_sizes_and_strides_as(const at::Tensor& the_template) {
    sizes_and_strides_.set_sizes(the_template.sizes());
  }

  ttnn::Tensor get_ttnn_tensor() {
    // LOGGING(ttnn_tensor_string_);
    LOGGING(ttnn_tensor_.write_to_string());
    return ttnn_tensor_;
  }

  void set_ttnn_tensor(const ttnn::Tensor& tensor) {
    ttnn_tensor_ = tensor;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<TtnnTensorImpl>(
        key_set(),
        dtype(),
        device(),
        ttnn_tensor_,
        storage_);
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/version_counter,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<TtnnTensorImpl>(
        key_set(),
        dtype(),
        device(),
        ttnn_tensor_,
        storage_);
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/std::move(version_counter),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Shallow-copies data from another TensorImpl into this TensorImpl.
   *
   * For why this function doesn't check this TensorImpl's
   * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
    auto ttnn_impl = static_cast<const TtnnTensorImpl*>(impl.get());
    copy_tensor_metadata(
        /*src_impl=*/ttnn_impl,
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
  }

  // protected:
  //  static void copy_tensor_metadata(
  //      const TtnnTensorImpl* src_impl,
  //      TtnnTensorImpl* dest_impl,
  //      const c10::VariableVersion& version_counter,
  //      bool allow_tensor_metadata_change) {
  //    TensorImpl::copy_tensor_metadata(
  //        src_impl,
  //        dest_impl,
  //        version_counter,
  //        allow_tensor_metadata_change);
  //
  //    // TtnnTensorImpl-specific fields.
  //    dest_impl->ttnn_tensor_ = src_impl->ttnn_tensor_;
  //    dest_impl->ttnn_tensor_string_ = src_impl->ttnn_tensor_string_;
  //  }
  //
  //  static void copy_tensor_metadata(
  //      const TtnnTensorImpl* src_impl,
  //      TtnnTensorImpl* dest_impl,
  //      c10::VariableVersion&& version_counter,
  //      bool allow_tensor_metadata_change) {
  //    TensorImpl::copy_tensor_metadata(
  //        src_impl,
  //        dest_impl,
  //        std::move(version_counter),
  //        allow_tensor_metadata_change);
  //
  //    // TtnnTensorImpl-specific fields.
  //    dest_impl->ttnn_tensor_ = src_impl->ttnn_tensor_;
  //    dest_impl->ttnn_tensor_string_ = src_impl->ttnn_tensor_string_;
  //  }

 private:
  ttnn::Tensor ttnn_tensor_;
  std::string ttnn_tensor_string_;
};

} // namespace at
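
For orientation, a short Python-side sketch of what this TensorImpl provides (a hedged illustration, not part of this commit; `device` and `ttnn_device` are assumed to be set up as in the tests above):

import torch
import ttnn
from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module

# The constructors above mirror ttnn_tensor.get_logical_shape() into the
# ATen sizes, so shape queries need no device round trip.
t = torch.rand((32, 1, 3, 3), dtype=torch.bfloat16).to(ttnn_device)
assert t.shape == (32, 1, 3, 3)

# ttnn_module.get_ttnn_tensor surfaces TtnnTensorImpl::get_ttnn_tensor,
# handing back the wrapped ttnn.Tensor for direct TTNN calls.
inner = ttnn_module.get_ttnn_tensor(t)
out = ttnn.to_torch(ttnn.abs(inner))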
