11#!/usr/bin/env python3
22import unittest
3- from typing import Type
43
54import torch
6-
75from captum .optim .models import clip_resnet50x4_image
86from captum .optim .models ._common import RedirectedReluLayer , SkipLayer
7+ from packaging import version
98from tests .helpers .basic import BaseTest , assertTensorAlmostEqual
109from tests .optim .helpers .models import check_layer_in_model
1110
1211
1312class TestCLIPResNet50x4Image (BaseTest ):
1413 def test_load_clip_resnet50x4_image_with_redirected_relu (self ) -> None :
15- if torch .__version__ <= "1.6.0" :
14+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
1615 raise unittest .SkipTest (
1716 "Skipping load pretrained CLIP ResNet 50x4 Image due to insufficient"
1817 + " Torch version."
@@ -23,7 +22,7 @@ def test_load_clip_resnet50x4_image_with_redirected_relu(self) -> None:
2322 self .assertTrue (check_layer_in_model (model , RedirectedReluLayer ))
2423
2524 def test_load_clip_resnet50x4_image_no_redirected_relu (self ) -> None :
26- if torch .__version__ <= "1.6.0" :
25+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
2726 raise unittest .SkipTest (
2827 "Skipping load pretrained CLIP ResNet 50x4 Image RedirectedRelu test"
2928 + " due to insufficient Torch version."
@@ -35,7 +34,7 @@ def test_load_clip_resnet50x4_image_no_redirected_relu(self) -> None:
3534 self .assertTrue (check_layer_in_model (model , torch .nn .ReLU ))
3635
3736 def test_load_clip_resnet50x4_image_linear (self ) -> None :
38- if torch .__version__ <= "1.6.0" :
37+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
3938 raise unittest .SkipTest (
4039 "Skipping load pretrained CLIP ResNet 50x4 Image linear test due to"
4140 + " insufficient Torch version."
@@ -46,7 +45,7 @@ def test_load_clip_resnet50x4_image_linear(self) -> None:
4645 self .assertTrue (check_layer_in_model (model , SkipLayer ))
4746
4847 def test_clip_resnet50x4_image_transform (self ) -> None :
49- if torch .__version__ <= "1.6.0" :
48+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
5049 raise unittest .SkipTest (
5150 "Skipping CLIP ResNet 50x4 Image internal transform test due to"
5251 + " insufficient Torch version."
@@ -63,20 +62,20 @@ def test_clip_resnet50x4_image_transform(self) -> None:
6362 assertTensorAlmostEqual (self , output , expected_output , 0 )
6463
6564 def test_clip_resnet50x4_image_transform_warning (self ) -> None :
66- if torch .__version__ <= "1.6.0" :
65+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
6766 raise unittest .SkipTest (
6867 "Skipping CLIP ResNet 50x4 Image internal transform warning test due"
6968 + " to insufficient Torch version."
7069 )
7170 x = torch .stack (
72- [torch .ones (3 , 112 , 112 ) * - 1 , torch .ones (3 , 112 , 112 ) * 2 ], dim = 0
71+ [torch .ones (3 , 288 , 288 ) * - 1 , torch .ones (3 , 288 , 288 ) * 2 ], dim = 0
7372 )
7473 model = clip_resnet50x4_image (pretrained = True )
7574 with self .assertWarns (UserWarning ):
7675 model ._transform_input (x )
7776
7877 def test_clip_resnet50x4_image_load_and_forward (self ) -> None :
79- if torch .__version__ <= "1.6.0" :
78+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
8079 raise unittest .SkipTest (
8180 "Skipping basic pretrained CLIP ResNet 50x4 Image forward test due to"
8281 + " insufficient Torch version."
@@ -87,7 +86,7 @@ def test_clip_resnet50x4_image_load_and_forward(self) -> None:
8786 self .assertEqual (list (output .shape ), [1 , 640 ])
8887
8988 def test_untrained_clip_resnet50x4_image_load_and_forward (self ) -> None :
90- if torch .__version__ <= "1.6.0" :
89+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
9190 raise unittest .SkipTest (
9291 "Skipping basic untrained CLIP ResNet 50x4 Image forward test due to"
9392 + " insufficient Torch version."
@@ -97,24 +96,21 @@ def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
9796 output = model (x )
9897 self .assertEqual (list (output .shape ), [1 , 640 ])
9998
100- def test_clip_resnet50x4_image_load_and_forward_diff_sizes (self ) -> None :
101- if torch .__version__ <= "1.6.0" :
99+ def test_clip_resnet50x4_image_warning (self ) -> None :
100+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
102101 raise unittest .SkipTest (
103- "Skipping pretrained CLIP ResNet 50x4 Image forward with different "
104- + " sized inputs test due to insufficient Torch version."
102+ "Skipping pretrained CLIP ResNet 50x4 Image transform input "
103+ + " warning test due to insufficient Torch version."
105104 )
106- x = torch .zeros (1 , 3 , 512 , 512 )
107- x2 = torch .zeros (1 , 3 , 126 , 224 )
105+ x = torch .stack (
106+ [torch .ones (3 , 288 , 288 ) * - 1 , torch .ones (3 , 288 , 288 ) * 2 ], dim = 0
107+ )
108108 model = clip_resnet50x4_image (pretrained = True )
109-
110- output = model (x )
111- output2 = model (x2 )
112-
113- self .assertEqual (list (output .shape ), [1 , 640 ])
114- self .assertEqual (list (output2 .shape ), [1 , 640 ])
109+ with self .assertWarns (UserWarning ):
110+ _ = model ._transform_input (x )
115111
116112 def test_clip_resnet50x4_image_forward_cuda (self ) -> None :
117- if torch .__version__ <= "1.6.0" :
113+ if version . parse ( torch .__version__ ) <= version . parse ( "1.6.0" ) :
118114 raise unittest .SkipTest (
119115 "Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
120116 + " insufficient Torch version."
@@ -124,23 +120,37 @@ def test_clip_resnet50x4_image_forward_cuda(self) -> None:
124120 "Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
125121 + " not supporting CUDA."
126122 )
127- x = torch .zeros (1 , 3 , 224 , 224 ).cuda ()
123+ x = torch .zeros (1 , 3 , 288 , 288 ).cuda ()
128124 model = clip_resnet50x4_image (pretrained = True ).cuda ()
129125 output = model (x )
130126
131127 self .assertTrue (output .is_cuda )
132128 self .assertEqual (list (output .shape ), [1 , 640 ])
133129
134130 def test_clip_resnet50x4_image_jit_module_no_redirected_relu (self ) -> None :
135- if torch .__version__ <= "1.8.0" :
131+ if version . parse ( torch .__version__ ) <= version . parse ( "1.8.0" ) :
136132 raise unittest .SkipTest (
137133 "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
138134 + " no redirected relu test due to insufficient Torch version."
139135 )
140- x = torch .zeros (1 , 3 , 224 , 224 )
136+ x = torch .zeros (1 , 3 , 288 , 288 )
141137 model = clip_resnet50x4_image (
142138 pretrained = True , replace_relus_with_redirectedrelu = False
143139 )
144140 jit_model = torch .jit .script (model )
145141 output = jit_model (x )
146142 self .assertEqual (list (output .shape ), [1 , 640 ])
143+
144+ def test_clip_resnet50x4_image_jit_module_with_redirected_relu (self ) -> None :
145+ if version .parse (torch .__version__ ) <= version .parse ("1.8.0" ):
146+ raise unittest .SkipTest (
147+ "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
148+ + " redirected relu test due to insufficient Torch version."
149+ )
150+ x = torch .zeros (1 , 3 , 288 , 288 )
151+ model = clip_resnet50x4_image (
152+ pretrained = True , replace_relus_with_redirectedrelu = True
153+ )
154+ jit_model = torch .jit .script (model )
155+ output = jit_model (x )
156+ self .assertEqual (list (output .shape ), [1 , 640 ])
0 commit comments