
Commit 8cafd7a

Fix model tests
1 parent 26bdda2 commit 8cafd7a

File tree

3 files changed: +42 -38 lines


captum/optim/models/_image/clip_resnet50x4_image.py (+14 -21)

@@ -1,9 +1,9 @@
-from typing import Optional, Type
+from typing import Any, Optional, Type
 from warnings import warn
 
 import torch
+import torch.nn as nn
 from captum.optim.models._common import RedirectedReluLayer, SkipLayer
-from torch import nn
 
 GS_SAVED_WEIGHTS_URL = (
     "https://pytorch.s3.amazonaws.com/models/captum/clip_resnet50x4_image.pt"
@@ -14,7 +14,7 @@ def clip_resnet50x4_image(
     pretrained: bool = False,
     progress: bool = True,
     model_path: Optional[str] = None,
-    **kwargs
+    **kwargs: Any,
 ) -> "CLIP_ResNet50x4Image":
     """
     The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
@@ -23,9 +23,8 @@ def clip_resnet50x4_image(
     This model can be combined with the CLIP ResNet 50x4 Text model to create the full
     CLIP ResNet 50x4 model.
 
-    AvgPool2d layers were replaced with AdaptiveAvgPool2d to allow for any input height
-    and width size, though the best results are obtained by using the model's intended
-    input height and width of 288x288.
+    Note that model inputs are expected to have a shape of: [B, 3, 288, 288] or
+    [3, 288, 288].
 
     See here for more details:
     https://github.com/openai/CLIP
@@ -124,13 +123,13 @@ def __init__(
         self.conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1, bias=False)
         self.bn3 = nn.BatchNorm2d(80)
         self.relu3 = activ()
-        self.avgpool = nn.AdaptiveAvgPool2d(72)
+        self.avgpool = nn.AvgPool2d(2)
 
         # Residual layers
-        self.layer1 = self._build_layer(80, 80, 4, stride=1, pooling=72, activ=activ)
-        self.layer2 = self._build_layer(320, 160, 6, stride=2, pooling=36, activ=activ)
-        self.layer3 = self._build_layer(640, 320, 10, stride=2, pooling=18, activ=activ)
-        self.layer4 = self._build_layer(1280, 640, 6, stride=2, pooling=9, activ=activ)
+        self.layer1 = self._build_layer(80, 80, blocks=4, stride=1, activ=activ)
+        self.layer2 = self._build_layer(320, 160, blocks=6, stride=2, activ=activ)
+        self.layer3 = self._build_layer(640, 320, blocks=10, stride=2, activ=activ)
+        self.layer4 = self._build_layer(1280, 640, blocks=6, stride=2, activ=activ)
 
         # Attention Pooling
         self.attnpool = AttentionPool2d(9, 2560, out_features=640, num_heads=40)
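
Why the fixed pool is equivalent at the intended size: with a 288x288 input, the stride-2 stem conv leaves a 144x144 activation before the pool, so nn.AvgPool2d(2) lands on the same 72x72 grid that nn.AdaptiveAvgPool2d(72) previously forced for arbitrary inputs. A minimal sketch (not part of the commit; the 144x144 stem size is inferred from the old pooling=72 default):

import torch
import torch.nn as nn

x = torch.rand(1, 80, 144, 144)  # assumed stem activation for a 288x288 input
out_adaptive = nn.AdaptiveAvgPool2d(72)(x)  # old: force 72x72 for any input
out_fixed = nn.AvgPool2d(2)(x)              # new: plain 2x2 pool, 144 -> 72
assert out_adaptive.shape == out_fixed.shape == (1, 80, 72, 72)
assert torch.allclose(out_adaptive, out_fixed)  # identical when 144 == 2 * 72
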
@@ -141,7 +140,6 @@ def _build_layer(
         planes: int = 80,
         blocks: int = 4,
         stride: int = 1,
-        pooling: int = 72,
         activ: Type[nn.Module] = nn.ReLU,
     ) -> nn.Module:
         """
@@ -160,18 +158,16 @@ def _build_layer(
                 Default: 4
             stride (int, optional): The stride value to use for the Bottleneck layers.
                 Default: 1
-            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
-                Default: 72
             activ (type of nn.Module, optional): The nn.Module class type to use for
                 activation layers.
                 Default: nn.ReLU
 
         Returns:
             residual_layer (nn.Sequential): A full residual layer.
         """
-        layers = [Bottleneck(inplanes, planes, stride, pooling=pooling, activ=activ)]
+        layers = [Bottleneck(inplanes, planes, stride, activ=activ)]
         for _ in range(blocks - 1):
-            layers += [Bottleneck(planes * 4, planes, pooling=pooling, activ=activ)]
+            layers += [Bottleneck(planes * 4, planes, activ=activ)]
         return nn.Sequential(*layers)
 
     def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
@@ -230,7 +226,6 @@ def __init__(
         inplanes: int = 80,
         planes: int = 80,
         stride: int = 1,
-        pooling: int = 72,
         activ: Type[nn.Module] = nn.ReLU,
     ) -> None:
         """
@@ -244,8 +239,6 @@ def __init__(
                 Default: 80
             stride (int, optional): The stride value to use for the Bottleneck layers.
                 Default: 1
-            pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
-                Default: 72
             activ (type of nn.Module, optional): The nn.Module class type to use for
                 activation layers.
                 Default: nn.ReLU
@@ -259,15 +252,15 @@ def __init__(
         self.bn2 = nn.BatchNorm2d(planes)
         self.relu2 = activ()
 
-        self.avgpool = nn.AdaptiveAvgPool2d(pooling)
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
 
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
         self.relu3 = activ()
 
         if stride > 1 or inplanes != planes * 4:
             self.downsample = nn.Sequential(
-                nn.AdaptiveAvgPool2d(pooling),
+                nn.AvgPool2d(stride),
                 nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(planes * 4),
             )
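
The Bottleneck change follows the same logic: the pool now only downsamples by the block's stride and degenerates to a no-op when stride is 1, instead of adaptively forcing a hard-coded output size. A hedged sketch of the pattern (helper name is illustrative, not the commit's code):

import torch
import torch.nn as nn

def make_pool(stride: int) -> nn.Module:
    # stride == 1 blocks keep their resolution; stride > 1 blocks reduce it
    return nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

x = torch.rand(1, 320, 72, 72)
assert make_pool(1)(x).shape == (1, 320, 72, 72)  # layer1: stride=1, no-op
assert make_pool(2)(x).shape == (1, 320, 36, 36)  # layer2-4: stride=2 halves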

captum/optim/models/_image/clip_resnet50x4_text.py (+3 -3)

@@ -1,8 +1,8 @@
 import math
-from typing import Optional
+from typing import Any, Optional
 
 import torch
-from torch import nn
+import torch.nn as nn
 
 
 GS_SAVED_WEIGHTS_URL = (
@@ -14,7 +14,7 @@ def clip_resnet50x4_text(
     pretrained: bool = False,
     progress: bool = True,
     model_path: Optional[str] = None,
-    **kwargs
+    **kwargs: Any,
 ) -> "CLIP_ResNet50x4Text":
     """
     The text portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
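
The **kwargs annotation here mirrors the image model's change: under strict type checking (e.g. mypy --strict), an unannotated **kwargs is an error, while **kwargs: Any types each keyword value as Any. A minimal illustration (hypothetical function, not from the commit):

from typing import Any

def loader(pretrained: bool = False, **kwargs: Any) -> None:
    ...  # inside the body, kwargs behaves as a Dict[str, Any]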

tests/optim/models/test_clip_resnet50x4_image.py (+25 -14)

@@ -68,7 +68,7 @@ def test_clip_resnet50x4_image_transform_warning(self) -> None:
             + " to insufficient Torch version."
         )
         x = torch.stack(
-            [torch.ones(3, 112, 112) * -1, torch.ones(3, 112, 112) * 2], dim=0
+            [torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
         )
         model = clip_resnet50x4_image(pretrained=True)
         with self.assertWarns(UserWarning):
@@ -96,21 +96,18 @@ def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
         output = model(x)
         self.assertEqual(list(output.shape), [1, 640])
 
-    def test_clip_resnet50x4_image_load_and_forward_diff_sizes(self) -> None:
+    def test_clip_resnet50x4_image_warning(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.6.0"):
             raise unittest.SkipTest(
-                "Skipping pretrained CLIP ResNet 50x4 Image forward with different"
-                + " sized inputs test due to insufficient Torch version."
+                "Skipping pretrained CLIP ResNet 50x4 Image transform input"
+                + " warning test due to insufficient Torch version."
             )
-        x = torch.zeros(1, 3, 512, 512)
-        x2 = torch.zeros(1, 3, 126, 224)
+        x = torch.stack(
+            [torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
+        )
         model = clip_resnet50x4_image(pretrained=True)
-
-        output = model(x)
-        output2 = model(x2)
-
-        self.assertEqual(list(output.shape), [1, 640])
-        self.assertEqual(list(output2.shape), [1, 640])
+        with self.assertWarns(UserWarning):
+            _ = model._transform_input(x)
 
     def test_clip_resnet50x4_image_forward_cuda(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.6.0"):
@@ -123,7 +120,7 @@ def test_clip_resnet50x4_image_forward_cuda(self) -> None:
                 "Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
                 + " not supporting CUDA."
             )
-        x = torch.zeros(1, 3, 224, 224).cuda()
+        x = torch.zeros(1, 3, 288, 288).cuda()
         model = clip_resnet50x4_image(pretrained=True).cuda()
         output = model(x)
@@ -136,10 +133,24 @@ def test_clip_resnet50x4_image_jit_module_no_redirected_relu(self) -> None:
                 "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
                 + " no redirected relu test due to insufficient Torch version."
             )
-        x = torch.zeros(1, 3, 224, 224)
+        x = torch.zeros(1, 3, 288, 288)
         model = clip_resnet50x4_image(
             pretrained=True, replace_relus_with_redirectedrelu=False
         )
         jit_model = torch.jit.script(model)
         output = jit_model(x)
         self.assertEqual(list(output.shape), [1, 640])
+
+    def test_clip_resnet50x4_image_jit_module_with_redirected_relu(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.8.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
+                + " redirected relu test due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(
+            pretrained=True, replace_relus_with_redirectedrelu=True
+        )
+        jit_model = torch.jit.script(model)
+        output = jit_model(x)
+        self.assertEqual(list(output.shape), [1, 640])
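
Taken together, the updated tests pin every forward pass to the model's intended 288x288 input, matching the new fixed-size pooling. A runnable sketch along the same lines (untrained weights to avoid the pretrained download; the import path is assumed from this test module):

import torch
from captum.optim.models import clip_resnet50x4_image

model = clip_resnet50x4_image(pretrained=False)
x = torch.zeros(1, 3, 288, 288)  # the intended input shape per the docstring
output = model(x)
assert list(output.shape) == [1, 640]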
