1
1
#!/usr/bin/env python3
2
2
import unittest
3
- from typing import Type
4
3
5
4
import torch
6
-
7
5
from captum .optim .models import clip_resnet50x4_image
8
6
from captum .optim .models ._common import RedirectedReluLayer , SkipLayer
7
+ from packaging import version
9
8
from tests .helpers .basic import BaseTest , assertTensorAlmostEqual
10
9
from tests .optim .helpers .models import check_layer_in_model
11
10
12
11
13
12
class TestCLIPResNet50x4Image (BaseTest ):
14
13
def test_load_clip_resnet50x4_image_with_redirected_relu (self ) -> None :
15
- if torch .__version__ <= "1.6.0" :
14
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
16
15
raise unittest .SkipTest (
17
16
"Skipping load pretrained CLIP ResNet 50x4 Image due to insufficient"
18
17
+ " Torch version."
@@ -23,7 +22,7 @@ def test_load_clip_resnet50x4_image_with_redirected_relu(self) -> None:
23
22
self .assertTrue (check_layer_in_model (model , RedirectedReluLayer ))
24
23
25
24
def test_load_clip_resnet50x4_image_no_redirected_relu (self ) -> None :
26
- if torch .__version__ <= "1.6.0" :
25
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
27
26
raise unittest .SkipTest (
28
27
"Skipping load pretrained CLIP ResNet 50x4 Image RedirectedRelu test"
29
28
+ " due to insufficient Torch version."
@@ -35,7 +34,7 @@ def test_load_clip_resnet50x4_image_no_redirected_relu(self) -> None:
35
34
self .assertTrue (check_layer_in_model (model , torch .nn .ReLU ))
36
35
37
36
def test_load_clip_resnet50x4_image_linear (self ) -> None :
38
- if torch .__version__ <= "1.6.0" :
37
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
39
38
raise unittest .SkipTest (
40
39
"Skipping load pretrained CLIP ResNet 50x4 Image linear test due to"
41
40
+ " insufficient Torch version."
@@ -46,7 +45,7 @@ def test_load_clip_resnet50x4_image_linear(self) -> None:
46
45
self .assertTrue (check_layer_in_model (model , SkipLayer ))
47
46
48
47
def test_clip_resnet50x4_image_transform (self ) -> None :
49
- if torch .__version__ <= "1.6.0" :
48
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
50
49
raise unittest .SkipTest (
51
50
"Skipping CLIP ResNet 50x4 Image internal transform test due to"
52
51
+ " insufficient Torch version."
@@ -63,20 +62,20 @@ def test_clip_resnet50x4_image_transform(self) -> None:
63
62
assertTensorAlmostEqual (self , output , expected_output , 0 )
64
63
65
64
def test_clip_resnet50x4_image_transform_warning (self ) -> None :
66
- if torch .__version__ <= "1.6.0" :
65
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
67
66
raise unittest .SkipTest (
68
67
"Skipping CLIP ResNet 50x4 Image internal transform warning test due"
69
68
+ " to insufficient Torch version."
70
69
)
71
70
x = torch .stack (
72
- [torch .ones (3 , 112 , 112 ) * - 1 , torch .ones (3 , 112 , 112 ) * 2 ], dim = 0
71
+ [torch .ones (3 , 288 , 288 ) * - 1 , torch .ones (3 , 288 , 288 ) * 2 ], dim = 0
73
72
)
74
73
model = clip_resnet50x4_image (pretrained = True )
75
74
with self .assertWarns (UserWarning ):
76
75
model ._transform_input (x )
77
76
78
77
def test_clip_resnet50x4_image_load_and_forward (self ) -> None :
79
- if torch .__version__ <= "1.6.0" :
78
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
80
79
raise unittest .SkipTest (
81
80
"Skipping basic pretrained CLIP ResNet 50x4 Image forward test due to"
82
81
+ " insufficient Torch version."
@@ -87,7 +86,7 @@ def test_clip_resnet50x4_image_load_and_forward(self) -> None:
87
86
self .assertEqual (list (output .shape ), [1 , 640 ])
88
87
89
88
def test_untrained_clip_resnet50x4_image_load_and_forward (self ) -> None :
90
- if torch .__version__ <= "1.6.0" :
89
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
91
90
raise unittest .SkipTest (
92
91
"Skipping basic untrained CLIP ResNet 50x4 Image forward test due to"
93
92
+ " insufficient Torch version."
@@ -97,24 +96,21 @@ def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
97
96
output = model (x )
98
97
self .assertEqual (list (output .shape ), [1 , 640 ])
99
98
100
- def test_clip_resnet50x4_image_load_and_forward_diff_sizes (self ) -> None :
101
- if torch .__version__ <= "1.6.0" :
99
+ def test_clip_resnet50x4_image_warning (self ) -> None :
100
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
102
101
raise unittest .SkipTest (
103
- "Skipping pretrained CLIP ResNet 50x4 Image forward with different "
104
- + " sized inputs test due to insufficient Torch version."
102
+ "Skipping pretrained CLIP ResNet 50x4 Image transform input "
103
+ + " warning test due to insufficient Torch version."
105
104
)
106
- x = torch .zeros (1 , 3 , 512 , 512 )
107
- x2 = torch .zeros (1 , 3 , 126 , 224 )
105
+ x = torch .stack (
106
+ [torch .ones (3 , 288 , 288 ) * - 1 , torch .ones (3 , 288 , 288 ) * 2 ], dim = 0
107
+ )
108
108
model = clip_resnet50x4_image (pretrained = True )
109
-
110
- output = model (x )
111
- output2 = model (x2 )
112
-
113
- self .assertEqual (list (output .shape ), [1 , 640 ])
114
- self .assertEqual (list (output2 .shape ), [1 , 640 ])
109
+ with self .assertWarns (UserWarning ):
110
+ _ = model ._transform_input (x )
115
111
116
112
def test_clip_resnet50x4_image_forward_cuda (self ) -> None :
117
- if torch .__version__ <= "1.6.0" :
113
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
118
114
raise unittest .SkipTest (
119
115
"Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
120
116
+ " insufficient Torch version."
@@ -124,23 +120,37 @@ def test_clip_resnet50x4_image_forward_cuda(self) -> None:
124
120
"Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
125
121
+ " not supporting CUDA."
126
122
)
127
- x = torch .zeros (1 , 3 , 224 , 224 ).cuda ()
123
+ x = torch .zeros (1 , 3 , 288 , 288 ).cuda ()
128
124
model = clip_resnet50x4_image (pretrained = True ).cuda ()
129
125
output = model (x )
130
126
131
127
self .assertTrue (output .is_cuda )
132
128
self .assertEqual (list (output .shape ), [1 , 640 ])
133
129
134
130
def test_clip_resnet50x4_image_jit_module_no_redirected_relu (self ) -> None :
135
- if torch .__version__ <= "1.8.0" :
131
+ if version . parse ( torch .__version__ ) <= version . parse ( "1.8.0" ) :
136
132
raise unittest .SkipTest (
137
133
"Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
138
134
+ " no redirected relu test due to insufficient Torch version."
139
135
)
140
- x = torch .zeros (1 , 3 , 224 , 224 )
136
+ x = torch .zeros (1 , 3 , 288 , 288 )
141
137
model = clip_resnet50x4_image (
142
138
pretrained = True , replace_relus_with_redirectedrelu = False
143
139
)
144
140
jit_model = torch .jit .script (model )
145
141
output = jit_model (x )
146
142
self .assertEqual (list (output .shape ), [1 , 640 ])
143
+
144
+ def test_clip_resnet50x4_image_jit_module_with_redirected_relu (self ) -> None :
145
+ if version .parse (torch .__version__ ) <= version .parse ("1.8.0" ):
146
+ raise unittest .SkipTest (
147
+ "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
148
+ + " redirected relu test due to insufficient Torch version."
149
+ )
150
+ x = torch .zeros (1 , 3 , 288 , 288 )
151
+ model = clip_resnet50x4_image (
152
+ pretrained = True , replace_relus_with_redirectedrelu = True
153
+ )
154
+ jit_model = torch .jit .script (model )
155
+ output = jit_model (x )
156
+ self .assertEqual (list (output .shape ), [1 , 640 ])
0 commit comments