Skip to content

Commit a265018

Browse files
committed
Fixed CI issues
1 parent 8b92866 commit a265018

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

tests/py/dynamo/models/test_export_serde.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def forward(self, x):
5252

5353
deser_trt_module = torchtrt.load(trt_ep_path).module()
5454
# Check Pyt and TRT exported program outputs
55+
model.cuda()
5556
cos_sim = cosine_similarity(model(input), trt_module(input)[0])
5657
assertions.assertTrue(
5758
cos_sim > COSINE_THRESHOLD,
@@ -106,6 +107,7 @@ def forward(self, x):
106107

107108
deser_trt_module = torchtrt.load(trt_ep_path).module()
108109
# Check Pyt and TRT exported program outputs
110+
model.cuda()
109111
outputs_pyt = model(input)
110112
outputs_trt = trt_module(input)
111113
for idx in range(len(outputs_pyt)):
@@ -162,8 +164,9 @@ def forward(self, x):
162164
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
163165
torchtrt.save(trt_module, trt_ep_path)
164166

165-
deser_trt_module = torchtrt.load(trt_ep_path).module()
167+
deser_trt_module = torchtrt.load(trt_ep_path).module().cuda()
166168
# Check Pyt and TRT exported program outputs
169+
model.cuda()
167170
outputs_pyt = model(input)
168171
outputs_trt = trt_module(input)
169172
for idx in range(len(outputs_pyt)):
@@ -173,7 +176,7 @@ def forward(self, x):
173176
msg=f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
174177
)
175178

176-
# # Check Pyt and deserialized TRT exported program outputs
179+
# Check Pyt and deserialized TRT exported program outputs
177180
outputs_trt_deser = deser_trt_module(input)
178181
for idx in range(len(outputs_pyt)):
179182
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
@@ -224,6 +227,7 @@ def forward(self, x):
224227
torchtrt.save(trt_module, trt_ep_path)
225228

226229
deser_trt_module = torchtrt.load(trt_ep_path).module()
230+
model.cuda()
227231
outputs_pyt = model(input)
228232
outputs_trt = trt_module(input)
229233
for idx in range(len(outputs_pyt)):
@@ -267,6 +271,7 @@ def test_resnet18(ir):
267271
torchtrt.save(trt_module, trt_ep_path)
268272

269273
deser_trt_module = torchtrt.load(trt_ep_path).module()
274+
model.cuda()
270275
outputs_pyt = model(input)
271276
outputs_trt = trt_module(input)
272277
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
@@ -312,6 +317,7 @@ def test_resnet18_dynamic(ir):
312317
torchtrt.save(trt_module, trt_ep_path)
313318
# TODO: Enable this serialization issues are fixed
314319
# deser_trt_module = torchtrt.load(trt_ep_path).module()
320+
model.cuda()
315321
outputs_pyt = model(input)
316322
outputs_trt = trt_module(input)
317323
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
@@ -362,6 +368,7 @@ def forward(self, x):
362368
torchtrt.save(trt_module, trt_ep_path)
363369

364370
deser_trt_module = torchtrt.load(trt_ep_path).module()
371+
model.cuda()
365372
outputs_pyt = model(input)
366373
outputs_trt = trt_module(input)
367374

@@ -420,6 +427,7 @@ def forward(self, x):
420427
torchtrt.save(trt_module, trt_ep_path)
421428

422429
deser_trt_module = torchtrt.load(trt_ep_path).module()
430+
model.cuda()
423431
outputs_pyt = model(input)
424432
outputs_trt = trt_module(input)
425433

tests/py/dynamo/runtime/test_004_weight_streaming.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_weight_streaming_default(self, _, use_python_runtime):
6969

7070
# Weight streaming budget is reverted after the exit from weight streaming context
7171
assert weight_streaming_ctx.device_budget == requested_budget
72-
72+
model.cuda()
7373
ref = model(*input)
7474
out = optimized_model(*input)
7575

@@ -129,7 +129,7 @@ def test_weight_streaming_manual(self, _, use_python_runtime):
129129
assert weight_streaming_ctx.device_budget == requested_budget
130130

131131
out = optimized_model(*input)
132-
132+
model.cuda()
133133
ref = model(*input)
134134
torch.testing.assert_close(
135135
out.cpu(),
@@ -168,6 +168,7 @@ def test_weight_streaming_invalid_usage(self, _, use_python_runtime, multi_rt):
168168
)
169169

170170
# Setting weight streaming context to unsupported module
171+
model.cuda()
171172
with torchtrt.runtime.weight_streaming(model) as weight_streaming_ctx:
172173
streamable_budget = weight_streaming_ctx.total_device_budget
173174
assert streamable_budget == 0
@@ -222,7 +223,7 @@ def test_weight_streaming_multi_rt(self, _, use_python_runtime):
222223
# Budget distribution to multiple submodule may result in integer differences of at most 1
223224
assert abs(weight_streaming_ctx.device_budget - requested_budget) <= 1
224225
out = optimized_model(*input)
225-
226+
model.cuda()
226227
ref = model(*input)
227228
torch.testing.assert_close(
228229
out.cpu(),
@@ -273,7 +274,7 @@ def test_weight_streaming_cudagraphs(self, _, use_python_runtime):
273274
weight_streaming_ctx.device_budget = requested_budget
274275
for _ in range(4):
275276
out = cudagraphs_module(*input)
276-
277+
model.cuda()
277278
ref = model(*input)
278279
torch.testing.assert_close(
279280
out.cpu(),
@@ -356,7 +357,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
356357
exp_program,
357358
**compile_spec,
358359
)
359-
360+
model.cuda()
360361
# List of tuples representing different configurations for three features:
361362
# Cuda graphs, pre-allocated output buffer, weight streaming change
362363
states = list(itertools.product((True, False), repeat=3))

0 commit comments

Comments
 (0)