
Commit 599d8e1

Update CLIP model for new testing & linting

1 parent e9598ea, commit 599d8e1

4 files changed: +65 -61 lines


captum/optim/models/_image/clip_resnet50x4_image.py (+15 -22)
@@ -1,9 +1,8 @@
-from typing import Optional, Type
+from typing import Any, Optional, Type
 from warnings import warn
 
 import torch
-from torch import nn
-
+import torch.nn as nn
 from captum.optim.models._common import RedirectedReluLayer, SkipLayer
 
 GS_SAVED_WEIGHTS_URL = (
@@ -15,7 +14,7 @@ def clip_resnet50x4_image(
     pretrained: bool = False,
     progress: bool = True,
     model_path: Optional[str] = None,
-    **kwargs
+    **kwargs: Any,
 ) -> "CLIP_ResNet50x4Image":
     """
     The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
@@ -24,9 +23,8 @@ def clip_resnet50x4_image(
     This model can be combined with the CLIP ResNet 50x4 Text model to create the full
     CLIP ResNet 50x4 model.
 
-    AvgPool2d layers were replaced with AdaptiveAvgPool2d to allow for any input height
-    and width size, though the best results are obtained by using the model's intended
-    input height and width of 288x288.
+    Note that model inputs are expected to have a shape of: [B, 3, 288, 288] or
+    [3, 288, 288].
 
     See here for more details:
     https://github.com/openai/CLIP
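The new docstring pins the expected input size. A minimal usage sketch (not part of the commit; pretrained=False is used here only to avoid the checkpoint download):

import torch
from captum.optim.models import clip_resnet50x4_image

model = clip_resnet50x4_image(pretrained=False)
embedding = model(torch.zeros(1, 3, 288, 288))  # input must be [B, 3, 288, 288]
print(embedding.shape)  # torch.Size([1, 640])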
@@ -82,6 +80,7 @@ class CLIP_ResNet50x4Image(nn.Module):
     The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
     Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
     """
+
     __constants__ = ["transform_input"]
 
     def __init__(
@@ -124,13 +123,13 @@ def __init__(
         self.conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1, bias=False)
         self.bn3 = nn.BatchNorm2d(80)
         self.relu3 = activ()
-        self.avgpool = nn.AdaptiveAvgPool2d(72)
+        self.avgpool = nn.AvgPool2d(2)
 
         # Residual layers
-        self.layer1 = self._build_layer(80, 80, 4, stride=1, pooling=72, activ=activ)
-        self.layer2 = self._build_layer(320, 160, 6, stride=2, pooling=36, activ=activ)
-        self.layer3 = self._build_layer(640, 320, 10, stride=2, pooling=18, activ=activ)
-        self.layer4 = self._build_layer(1280, 640, 6, stride=2, pooling=9, activ=activ)
+        self.layer1 = self._build_layer(80, 80, blocks=4, stride=1, activ=activ)
+        self.layer2 = self._build_layer(320, 160, blocks=6, stride=2, activ=activ)
+        self.layer3 = self._build_layer(640, 320, blocks=10, stride=2, activ=activ)
+        self.layer4 = self._build_layer(1280, 640, blocks=6, stride=2, activ=activ)
 
         # Attention Pooling
         self.attnpool = AttentionPool2d(9, 2560, out_features=640, num_heads=40)
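The fixed pooling sizes are why 288x288 becomes the required input: the stem halves the resolution twice (conv1 with stride 2, then AvgPool2d(2)), and layer2 through layer4 each halve it again, landing exactly on the 9x9 grid that AttentionPool2d(9, 2560, ...) expects. A quick arithmetic check:

# 288 -> 144 (conv1, stride 2) -> 72 (AvgPool2d(2))
# -> 72 (layer1) -> 36 (layer2) -> 18 (layer3) -> 9 (layer4)
assert 288 // 2 // 2 // 2 // 2 // 2 == 9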
@@ -141,7 +140,6 @@ def _build_layer(
         planes: int = 80,
         blocks: int = 4,
         stride: int = 1,
-        pooling: int = 72,
         activ: Type[nn.Module] = nn.ReLU,
     ) -> nn.Module:
         """
@@ -160,18 +158,16 @@ def _build_layer(
                 Default: 4
             stride (int, optional): The stride value to use for the Bottleneck layers.
                 Default: 1
-            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
-                Default: 72
             activ (type of nn.Module, optional): The nn.Module class type to use for
                 activation layers.
                 Default: nn.ReLU
 
         Returns:
             residual_layer (nn.Sequential): A full residual layer.
         """
-        layers = [Bottleneck(inplanes, planes, stride, pooling=pooling, activ=activ)]
+        layers = [Bottleneck(inplanes, planes, stride, activ=activ)]
         for _ in range(blocks - 1):
-            layers += [Bottleneck(planes * 4, planes, pooling=pooling, activ=activ)]
+            layers += [Bottleneck(planes * 4, planes, activ=activ)]
         return nn.Sequential(*layers)
 
     def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
@@ -230,7 +226,6 @@ def __init__(
         inplanes: int = 80,
         planes: int = 80,
         stride: int = 1,
-        pooling: int = 72,
         activ: Type[nn.Module] = nn.ReLU,
     ) -> None:
         """
@@ -244,8 +239,6 @@ def __init__(
                 Default: 80
             stride (int, optional): The stride value to use for the Bottleneck layers.
                 Default: 1
-            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
-                Default: 72
             activ (type of nn.Module, optional): The nn.Module class type to use for
                 activation layers.
                 Default: nn.ReLU
@@ -259,15 +252,15 @@ def __init__(
         self.bn2 = nn.BatchNorm2d(planes)
         self.relu2 = activ()
 
-        self.avgpool = nn.AdaptiveAvgPool2d(pooling)
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
 
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * 4)
         self.relu3 = activ()
 
         if stride > 1 or inplanes != planes * 4:
             self.downsample = nn.Sequential(
-                nn.AdaptiveAvgPool2d(pooling),
+                nn.AvgPool2d(stride),
                 nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(planes * 4),
             )
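At the model's native resolution the old and new pooling are numerically identical, since AdaptiveAvgPool2d with an output size that evenly divides its input degenerates to a fixed-kernel average pool. A small sketch verifying this for layer2's 72x72 -> 36x36 reduction (channel and spatial sizes taken from the constructor above):

import torch
import torch.nn as nn

x = torch.randn(1, 320, 72, 72)
# AdaptiveAvgPool2d(36) on a 72x72 map uses kernel = stride = 2
assert torch.allclose(nn.AdaptiveAvgPool2d(36)(x), nn.AvgPool2d(2)(x))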

captum/optim/models/_image/clip_resnet50x4_text.py (+5 -4)
@@ -1,8 +1,8 @@
-from typing import Optional
-
 import math
+from typing import Any, Optional
+
 import torch
-from torch import nn
+import torch.nn as nn
 
 
 GS_SAVED_WEIGHTS_URL = (
@@ -14,7 +14,7 @@ def clip_resnet50x4_text(
     pretrained: bool = False,
     progress: bool = True,
     model_path: Optional[str] = None,
-    **kwargs
+    **kwargs: Any,
 ) -> "CLIP_ResNet50x4Text":
     """
     The text portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
@@ -72,6 +72,7 @@ class CLIP_ResNet50x4Text(nn.Module):
     The text portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
     Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
     """
+
    def __init__(
         self,
         width: int = 640,
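Annotating `**kwargs: Any` types the values of the keyword arguments; strict linting (for example, mypy with --disallow-untyped-defs) treats a bare `**kwargs` as an incomplete signature. A minimal sketch of the pattern, with a hypothetical factory name:

from typing import Any

def model_factory(pretrained: bool = False, **kwargs: Any) -> None:
    ...  # each kwargs value is typed as Any; the keys are always str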
tests/optim/models/test_clip_resnet50x4_image.py (+36 -26)
@@ -1,18 +1,17 @@
 #!/usr/bin/env python3
 import unittest
-from typing import Type
 
 import torch
-
 from captum.optim.models import clip_resnet50x4_image
 from captum.optim.models._common import RedirectedReluLayer, SkipLayer
+from packaging import version
 from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
 from tests.optim.helpers.models import check_layer_in_model
 
 
 class TestCLIPResNet50x4Image(BaseTest):
     def test_load_clip_resnet50x4_image_with_redirected_relu(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping load pretrained CLIP ResNet 50x4 Image due to insufficient"
                 + " Torch version."
@@ -23,7 +22,7 @@ def test_load_clip_resnet50x4_image_with_redirected_relu(self) -> None:
         self.assertTrue(check_layer_in_model(model, RedirectedReluLayer))
 
     def test_load_clip_resnet50x4_image_no_redirected_relu(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping load pretrained CLIP ResNet 50x4 Image RedirectedRelu test"
                 + " due to insufficient Torch version."
@@ -35,7 +34,7 @@ def test_load_clip_resnet50x4_image_no_redirected_relu(self) -> None:
         self.assertTrue(check_layer_in_model(model, torch.nn.ReLU))
 
     def test_load_clip_resnet50x4_image_linear(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping load pretrained CLIP ResNet 50x4 Image linear test due to"
                 + " insufficient Torch version."
@@ -46,7 +45,7 @@ def test_load_clip_resnet50x4_image_linear(self) -> None:
         self.assertTrue(check_layer_in_model(model, SkipLayer))
 
     def test_clip_resnet50x4_image_transform(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping CLIP ResNet 50x4 Image internal transform test due to"
                 + " insufficient Torch version."
@@ -63,20 +62,20 @@ def test_clip_resnet50x4_image_transform(self) -> None:
         assertTensorAlmostEqual(self, output, expected_output, 0)
 
     def test_clip_resnet50x4_image_transform_warning(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping CLIP ResNet 50x4 Image internal transform warning test due"
                 + " to insufficient Torch version."
             )
         x = torch.stack(
-            [torch.ones(3, 112, 112) * -1, torch.ones(3, 112, 112) * 2], dim=0
+            [torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
         )
         model = clip_resnet50x4_image(pretrained=True)
         with self.assertWarns(UserWarning):
             model._transform_input(x)
 
     def test_clip_resnet50x4_image_load_and_forward(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping basic pretrained CLIP ResNet 50x4 Image forward test due to"
                 + " insufficient Torch version."
@@ -87,7 +86,7 @@ def test_clip_resnet50x4_image_load_and_forward(self) -> None:
         self.assertEqual(list(output.shape), [1, 640])
 
     def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping basic untrained CLIP ResNet 50x4 Image forward test due to"
                 + " insufficient Torch version."
@@ -97,24 +96,21 @@ def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
         output = model(x)
         self.assertEqual(list(output.shape), [1, 640])
 
-    def test_clip_resnet50x4_image_load_and_forward_diff_sizes(self) -> None:
-        if torch.__version__ <= "1.6.0":
+    def test_clip_resnet50x4_image_warning(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
-                "Skipping pretrained CLIP ResNet 50x4 Image forward with different"
-                + " sized inputs test due to insufficient Torch version."
+                "Skipping pretrained CLIP ResNet 50x4 Image transform input"
+                + " warning test due to insufficient Torch version."
             )
-        x = torch.zeros(1, 3, 512, 512)
-        x2 = torch.zeros(1, 3, 126, 224)
+        x = torch.stack(
+            [torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
+        )
         model = clip_resnet50x4_image(pretrained=True)
-
-        output = model(x)
-        output2 = model(x2)
-
-        self.assertEqual(list(output.shape), [1, 640])
-        self.assertEqual(list(output2.shape), [1, 640])
+        with self.assertWarns(UserWarning):
+            _ = model._transform_input(x)
 
     def test_clip_resnet50x4_image_forward_cuda(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
                 + " insufficient Torch version."
@@ -124,23 +120,37 @@ def test_clip_resnet50x4_image_forward_cuda(self) -> None:
                 "Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
                 + " not supporting CUDA."
             )
-        x = torch.zeros(1, 3, 224, 224).cuda()
+        x = torch.zeros(1, 3, 288, 288).cuda()
         model = clip_resnet50x4_image(pretrained=True).cuda()
         output = model(x)
 
         self.assertTrue(output.is_cuda)
         self.assertEqual(list(output.shape), [1, 640])
 
     def test_clip_resnet50x4_image_jit_module_no_redirected_relu(self) -> None:
-        if torch.__version__ <= "1.8.0":
+        if version.parse(torch.__version__) <= version.parse("1.8.0"):
             raise unittest.SkipTest(
                 "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
                 + " no redirected relu test due to insufficient Torch version."
             )
-        x = torch.zeros(1, 3, 224, 224)
+        x = torch.zeros(1, 3, 288, 288)
         model = clip_resnet50x4_image(
             pretrained=True, replace_relus_with_redirectedrelu=False
         )
         jit_model = torch.jit.script(model)
         output = jit_model(x)
         self.assertEqual(list(output.shape), [1, 640])
+
+    def test_clip_resnet50x4_image_jit_module_with_redirected_relu(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.8.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
+                + " redirected relu test due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(
+            pretrained=True, replace_relus_with_redirectedrelu=True
+        )
+        jit_model = torch.jit.script(model)
+        output = jit_model(x)
+        self.assertEqual(list(output.shape), [1, 640])
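The new test mirrors the existing one but keeps the RedirectedReLU layers in place when scripting. A standalone sketch of the same check (pretrained=False here to avoid the weight download; the untrained forward test above shows the output shape is unchanged):

import torch
from captum.optim.models import clip_resnet50x4_image

model = clip_resnet50x4_image(
    pretrained=False, replace_relus_with_redirectedrelu=True
)
jit_model = torch.jit.script(model)
output = jit_model(torch.zeros(1, 3, 288, 288))
print(list(output.shape))  # [1, 640]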

tests/optim/models/test_clip_resnet50x4_text.py (+9 -9)
@@ -2,37 +2,37 @@
 import unittest
 
 import torch
-
 from captum.optim.models import clip_resnet50x4_text
+from packaging import version
 from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
 
 
 class TestCLIPResNet50x4Text(BaseTest):
     def test_clip_resnet50x4_text_logit_scale(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping basic pretrained CLIP ResNet 50x4 Text logit scale test due"
                 + " to insufficient Torch version."
             )
         model = clip_resnet50x4_text(pretrained=True)
-        expected_logit_scale = torch.tensor([4.605170249938965])
+        expected_logit_scale = torch.tensor(4.605170249938965)
         assertTensorAlmostEqual(self, model.logit_scale, expected_logit_scale)
 
     def test_clip_resnet50x4_text_load_and_forward(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping basic pretrained CLIP ResNet 50x4 Text forward test due to"
                 + " insufficient Torch version."
             )
         # Start & End tokens: 49405, 49406
         x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
-        x = x.int()[None, :]
+        x = x[None, :].long()
         model = clip_resnet50x4_text(pretrained=True)
         output = model(x)
         self.assertEqual(list(output.shape), [1, 640])
 
     def test_clip_resnet50x4_text_forward_cuda(self) -> None:
-        if torch.__version__ <= "1.6.0":
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
                 "Skipping pretrained CLIP ResNet 50x4 Text forward CUDA test due to"
                 + " insufficient Torch version."
@@ -43,21 +43,21 @@ def test_clip_resnet50x4_text_forward_cuda(self) -> None:
                 + " not supporting CUDA."
             )
         x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)]).cuda()
-        x = x.int()[None, :]
+        x = x[None, :].long()
         model = clip_resnet50x4_text(pretrained=True).cuda()
         output = model(x)
 
         self.assertTrue(output.is_cuda)
         self.assertEqual(list(output.shape), [1, 640])
 
     def test_clip_resnet50x4_text_jit_module(self) -> None:
-        if torch.__version__ <= "1.8.0":
+        if version.parse(torch.__version__) <= version.parse("1.8.0"):
             raise unittest.SkipTest(
                 "Skipping pretrained CLIP ResNet 50x4 Text load & JIT module"
                 + " test due to insufficient Torch version."
             )
         x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
-        x = x.int()[None, :]
+        x = x[None, :].long()
         model = clip_resnet50x4_text(pretrained=True)
         jit_model = torch.jit.script(model)
         output = jit_model(x)
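The x.int() -> x[None, :].long() change also addresses a dtype quirk: torch.cat promotes the int64 token ids to float32 (torch.zeros defaults to float), and embedding lookups in PyTorch traditionally require int64 (torch.long) indices, so .long() is the safer cast. A small sketch, using a hypothetical CLIP-sized embedding table:

import torch

x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
print(x.dtype)  # torch.float32 -- cat promoted the integer ids
x = x[None, :].long()  # [1, 77] token ids as int64
embedding = torch.nn.Embedding(49408, 640)  # hypothetical vocab/width
print(embedding(x).shape)  # torch.Size([1, 77, 640])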
