Skip to content

Commit 3dca15b

Browse files
authored
Fix model input in tests
1 parent 8cafd7a commit 3dca15b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/optim/models/test_clip_resnet50x4_text.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_clip_resnet50x4_text_load_and_forward(self) -> None:
2626
)
2727
# Start & End tokens: 49405, 49406
2828
x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
29-
x = x.int()[None, :]
29+
x = x[None, :].long()
3030
model = clip_resnet50x4_text(pretrained=True)
3131
output = model(x)
3232
self.assertEqual(list(output.shape), [1, 640])
@@ -43,7 +43,7 @@ def test_clip_resnet50x4_text_forward_cuda(self) -> None:
4343
+ " not supporting CUDA."
4444
)
4545
x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)]).cuda()
46-
x = x.int()[None, :]
46+
x = x[None, :].long()
4747
model = clip_resnet50x4_text(pretrained=True).cuda()
4848
output = model(x)
4949

@@ -57,7 +57,7 @@ def test_clip_resnet50x4_text_jit_module(self) -> None:
5757
+ " test due to insufficient Torch version."
5858
)
5959
x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
60-
x = x.int()[None, :]
60+
x = x[None, :].long()
6161
model = clip_resnet50x4_text(pretrained=True)
6262
jit_model = torch.jit.script(model)
6363
output = jit_model(x)

0 commit comments

Comments
 (0)