import torch_ttnn
from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module, enable_ttnn_device
import pytest
+ import time

from transformers import AutoTokenizer, AutoModelForQuestionAnswering

import logging
import sys

+
@pytest.mark.parametrize(
    "input_shape",
-     ((32, 1, 3, 3), (32,)),
+     ((32, 1, 3, 3), (1, 32)),
+ )
+ @pytest.mark.parametrize(
+     "dtype",
+     (torch.bfloat16, torch.int32),
)
- def test_cpp_extension(device, input_shape):
-     torch.utils.rename_privateuse1_backend('ttnn')
+ def test_cpp_extension(device, input_shape, dtype):
+     torch.utils.rename_privateuse1_backend("ttnn")

    # in pytest the device has already been initialized before this call
    # so instead we can wrap this around the custom device
    ttnn_device = ttnn_module.custom_device_from_ttnn(device)

    logging.info("Creating bfloat tensor from -1 to 1")
-     torch_tensor = torch.empty(input_shape, dtype=torch.bfloat16).uniform_(-1, 1)
+     if dtype == torch.bfloat16:
+         torch_tensor = torch.empty(input_shape, dtype=dtype).uniform_(-1, 1)
+     elif dtype == torch.int32:
+         torch_tensor = torch.randint(-1000, 1000, input_shape)
+         torch_tensor = torch_tensor.to(torch.int32)
+     else:
+         raise Exception(f"{dtype} not being tested at this time")
    print(torch_tensor)
-     torch_tensor_abs = torch.abs(torch_tensor)
-     print(torch_tensor_abs)

    logging.info("Transferring to ttnn")
    torch_ttnn_tensor = torch_tensor.to(ttnn_device)

-     logging.info("get underlying ttnn tensor")
+     logging.info("Get underlying ttnn tensor")
    ttnn_tensor = ttnn_module.get_ttnn_tensor(torch_ttnn_tensor)

-     logging.info("Running abs on ttnn")
-     ttnn_tensor = ttnn.abs(ttnn_tensor)
+     # Compare output of abs op for bfloat16 dtype since ttnn.abs does not support int
+     if dtype == torch.bfloat16:
+         torch_out = torch.abs(torch_tensor)
+         print(torch_out)
+
+         logging.info("Running abs on ttnn")
+         ttnn_tensor = ttnn.abs(ttnn_tensor)
+     elif dtype == torch.int32:
+         torch_out = torch_tensor
+     else:
+         raise Exception(f"{dtype} not being tested at this time")

    logging.info("calling to_torch")
    ttnn_to_torch = ttnn.to_torch(ttnn_tensor)
+
    print(ttnn_to_torch)
-
-
-     assert torch.allclose(torch_tensor_abs, ttnn_to_torch, rtol=0.1, atol=0.1)

-     # logging.info("Closing device")
-     # ttnn_module.close_custom_device(ttnn_device)
+     assert torch.allclose(torch_out, ttnn_to_torch, rtol=0.1, atol=0.1)
+

def test_bert_with_cpp_extension(device):
    model_name = "phiyodr/bert-large-finetuned-squad2"
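
For reference, the round trip the parametrized test above exercises reduces to the following pattern. A minimal sketch, assuming an already-open ttnn `device` (as the pytest fixture supplies) and using only the helpers the test itself imports:

```python
import torch
import ttnn
from torch_ttnn.cpp_extension.custom_device_mode import ttnn_module

torch.utils.rename_privateuse1_backend("ttnn")
ttnn_device = ttnn_module.custom_device_from_ttnn(device)  # wrap the open device

x = torch.empty((32, 1, 3, 3), dtype=torch.bfloat16).uniform_(-1, 1)
x_tt = x.to(ttnn_device)                # host -> custom device transfer
tt = ttnn_module.get_ttnn_tensor(x_tt)  # unwrap the backing ttnn tensor
tt = ttnn.abs(tt)                       # run the op natively on ttnn
assert torch.allclose(torch.abs(x), ttnn.to_torch(tt), rtol=0.1, atol=0.1)
```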
@@ -66,34 +83,45 @@ def test_bert_with_cpp_extension(device):
    )

    option = torch_ttnn.TorchTtnnOption(
-         device=device,
-         gen_graphviz=False,
-         run_mem_analysis=False,
-         metrics_path=model_name,
-         verbose=True,
-     )
+         device=device,
+         gen_graphviz=False,
+         run_mem_analysis=False,
+         metrics_path=model_name,
+         verbose=True,
+     )

    # custom device
-     torch.utils.rename_privateuse1_backend('ttnn')
+     torch.utils.rename_privateuse1_backend("ttnn")
    ttnn_device = ttnn_module.custom_device_from_ttnn(device)
-
+
    # clone input_ids on cpu since the data transfer is somehow in-place?
    input_ids = inputs.input_ids.clone()
-
-     inputs = inputs.to(ttnn_device)
-     # modules are inplace, tensors are not
-     m.to(ttnn_device)

-     model = torch.compile(m, backend=torch_ttnn.backend, options=option)
-     outputs = model(**inputs)
-

    # Helper function to decode output to human-readable text
    def decode_output(outputs):
        response_start = torch.argmax(outputs.start_logits)
        response_end = torch.argmax(outputs.end_logits) + 1
        response_tokens = input_ids[0, response_start:response_end]
        return tokenizer.decode(response_tokens)

+     # comment these out to disable the cpp extension
+     start_to = time.perf_counter() * 1000
+     inputs = inputs.to(ttnn_device)
+     # modules are inplace, tensors are not
+     m.to(ttnn_device)
+     end_to = time.perf_counter() * 1000
+     print(f"to: {end_to - start_to} (ms)")
+
+     model = torch.compile(m, backend=torch_ttnn.backend, options=option)
+
+     for idx in range(5):
+         start = time.perf_counter() * 1000
+         # No need to reset options if the inputs don't change, thanks to the compile cache
+         outputs = model(**inputs)
+         end = time.perf_counter() * 1000
+         run_time = end - start
+         print(f"iter {idx}: {run_time} (ms)")
+
    print("finished:")
    print(outputs)
    answer = decode_output(outputs)
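
The five-iteration loop added above is the usual warm-up/steady-state measurement: iteration 0 includes torch.compile tracing and compilation, while later iterations with unchanged input shapes and dtypes hit the cached compiled graph. A generic sketch of the same pattern (`time_iters` and `fn` are illustrative names, not part of this repo):

```python
import time

def time_iters(fn, n=5):
    # Iteration 0 pays the compile cost; the rest measure steady-state latency.
    for idx in range(n):
        start = time.perf_counter() * 1000
        fn()
        print(f"iter {idx}: {time.perf_counter() * 1000 - start} (ms)")

# e.g. time_iters(lambda: model(**inputs))
```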
@@ -108,9 +136,10 @@ def decode_output(outputs):
    """
    )

+
# adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py
class MnistModel(torch.nn.Module):
-     def __init__(self):
+     def __init__(self):
        super(MnistModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
@@ -133,8 +162,9 @@ def forward(self, x):
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x
-
- def test_mnist_with_cpp_extension(device):
+
+ @pytest.mark.skip(reason="Does not support conv for now")
+ def test_mnist_with_cpp_extension(device):
    model_name = "Mnist"
    transform = transforms.Compose([transforms.ToTensor()])
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
@@ -143,23 +173,21 @@ def test_mnist_with_cpp_extension(device):
    test_input = test_input.to(torch.bfloat16)

    # Copy weights and biases to ttnn
-     torch.utils.rename_privateuse1_backend('ttnn')
+     torch.utils.rename_privateuse1_backend("ttnn")
    ttnn_device = ttnn_module.custom_device_from_ttnn(device)
-

-
    option = torch_ttnn.TorchTtnnOption(
-         device=device,
-         gen_graphviz=False,
-         run_mem_analysis=False,
-         metrics_path=model_name,
-         verbose=True,
-     )
+         device=device,
+         gen_graphviz=False,
+         run_mem_analysis=False,
+         metrics_path=model_name,
+         verbose=True,
+     )

    model = MnistModel()
    model = model.to(torch.bfloat16)
    test_input = test_input.to(ttnn_device)
    model.to(ttnn_device)
-
+
    model = torch.compile(model, backend=torch_ttnn.backend, options=option)
-     results = model(test_input)
+     results = model(test_input)
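
Neither test closes the wrapped device explicitly: the removed lines in the first hunk show `ttnn_module.close_custom_device(ttnn_device)` already commented out, which suggests the pytest `device` fixture owns the device's lifetime. A sketch of standalone (non-fixture) usage, under that assumption:

```python
# Hypothetical standalone usage; inside these tests the pytest fixture
# opens and closes the underlying ttnn device instead.
ttnn_device = ttnn_module.custom_device_from_ttnn(device)
try:
    results = model(test_input)  # transfers and compiled calls as above
finally:
    ttnn_module.close_custom_device(ttnn_device)  # helper seen in the removed lines
```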