@@ -52,6 +52,7 @@ def forward(self, x):
52
52
53
53
deser_trt_module = torchtrt .load (trt_ep_path ).module ()
54
54
# Check Pyt and TRT exported program outputs
55
+ model .cuda ()
55
56
cos_sim = cosine_similarity (model (input ), trt_module (input )[0 ])
56
57
assertions .assertTrue (
57
58
cos_sim > COSINE_THRESHOLD ,
@@ -106,6 +107,7 @@ def forward(self, x):
106
107
107
108
deser_trt_module = torchtrt .load (trt_ep_path ).module ()
108
109
# Check Pyt and TRT exported program outputs
110
+ model .cuda ()
109
111
outputs_pyt = model (input )
110
112
outputs_trt = trt_module (input )
111
113
for idx in range (len (outputs_pyt )):
@@ -162,8 +164,9 @@ def forward(self, x):
162
164
trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
163
165
torchtrt .save (trt_module , trt_ep_path )
164
166
165
- deser_trt_module = torchtrt .load (trt_ep_path ).module ()
167
+ deser_trt_module = torchtrt .load (trt_ep_path ).module (). cuda ()
166
168
# Check Pyt and TRT exported program outputs
169
+ model .cuda ()
167
170
outputs_pyt = model (input )
168
171
outputs_trt = trt_module (input )
169
172
for idx in range (len (outputs_pyt )):
@@ -173,7 +176,7 @@ def forward(self, x):
173
176
msg = f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
174
177
)
175
178
176
- # # Check Pyt and deserialized TRT exported program outputs
179
+ # Check Pyt and deserialized TRT exported program outputs
177
180
outputs_trt_deser = deser_trt_module (input )
178
181
for idx in range (len (outputs_pyt )):
179
182
cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
@@ -224,6 +227,7 @@ def forward(self, x):
224
227
torchtrt .save (trt_module , trt_ep_path )
225
228
226
229
deser_trt_module = torchtrt .load (trt_ep_path ).module ()
230
+ model .cuda ()
227
231
outputs_pyt = model (input )
228
232
outputs_trt = trt_module (input )
229
233
for idx in range (len (outputs_pyt )):
@@ -267,6 +271,7 @@ def test_resnet18(ir):
267
271
torchtrt .save (trt_module , trt_ep_path )
268
272
269
273
deser_trt_module = torchtrt .load (trt_ep_path ).module ()
274
+ model .cuda ()
270
275
outputs_pyt = model (input )
271
276
outputs_trt = trt_module (input )
272
277
cos_sim = cosine_similarity (outputs_pyt , outputs_trt [0 ])
@@ -312,6 +317,7 @@ def test_resnet18_dynamic(ir):
312
317
torchtrt .save (trt_module , trt_ep_path )
313
318
# TODO: Enable this once serialization issues are fixed
314
319
# deser_trt_module = torchtrt.load(trt_ep_path).module()
320
+ model .cuda ()
315
321
outputs_pyt = model (input )
316
322
outputs_trt = trt_module (input )
317
323
cos_sim = cosine_similarity (outputs_pyt , outputs_trt [0 ])
@@ -362,6 +368,7 @@ def forward(self, x):
362
368
torchtrt .save (trt_module , trt_ep_path )
363
369
364
370
deser_trt_module = torchtrt .load (trt_ep_path ).module ()
371
+ model .cuda ()
365
372
outputs_pyt = model (input )
366
373
outputs_trt = trt_module (input )
367
374
@@ -420,6 +427,7 @@ def forward(self, x):
420
427
torchtrt .save (trt_module , trt_ep_path )
421
428
422
429
deser_trt_module = torchtrt .load (trt_ep_path ).module ()
430
+ model .cuda ()
423
431
outputs_pyt = model (input )
424
432
outputs_trt = trt_module (input )
425
433
0 commit comments