@@ -12,30 +12,33 @@ def initialize(self):
12
12
def check_input_shape (self , x ):
13
13
14
14
h , w , d = x .shape [- 3 :]
15
- if self .encoder .strides is not None :
16
- hs , ws , ds = 1 , 1 , 1
17
- for stride in self .encoder .strides :
18
- hs *= stride [0 ]
19
- ws *= stride [1 ]
20
- ds *= stride [2 ]
21
- if h % hs != 0 or w % ws != 0 or d % ds != 0 :
22
- new_h = (h // hs + 1 ) * hs if h % hs != 0 else h
23
- new_w = (w // ws + 1 ) * ws if w % ws != 0 else w
24
- new_d = (d // ds + 1 ) * ds if d % ds != 0 else d
25
- raise RuntimeError (
26
- f"Wrong input shape height={ h } , width={ w } , depth={ d } . Expected image height and width and depth "
27
- f"divisible by { hs } , { ws } , { ds } . Consider pad your images to shape ({ new_h } , { new_w } , { new_d } )."
28
- )
29
- else :
30
- output_stride = self .encoder .output_stride
31
- if h % output_stride != 0 or w % output_stride != 0 or d % output_stride != 0 :
32
- new_h = (h // output_stride + 1 ) * output_stride if h % output_stride != 0 else h
33
- new_w = (w // output_stride + 1 ) * output_stride if w % output_stride != 0 else w
34
- new_d = (d // output_stride + 1 ) * output_stride if d % output_stride != 0 else d
35
- raise RuntimeError (
36
- f"Wrong input shape height={ h } , width={ w } , depth={ d } . Expected image height and width and depth "
37
- f"divisible by { output_stride } . Consider pad your images to shape ({ new_h } , { new_w } , { new_d } )."
38
- )
15
+ try :
16
+ if self .encoder .strides is not None :
17
+ hs , ws , ds = 1 , 1 , 1
18
+ for stride in self .encoder .strides :
19
+ hs *= stride [0 ]
20
+ ws *= stride [1 ]
21
+ ds *= stride [2 ]
22
+ if h % hs != 0 or w % ws != 0 or d % ds != 0 :
23
+ new_h = (h // hs + 1 ) * hs if h % hs != 0 else h
24
+ new_w = (w // ws + 1 ) * ws if w % ws != 0 else w
25
+ new_d = (d // ds + 1 ) * ds if d % ds != 0 else d
26
+ raise RuntimeError (
27
+ f"Wrong input shape height={ h } , width={ w } , depth={ d } . Expected image height and width and depth "
28
+ f"divisible by { hs } , { ws } , { ds } . Consider pad your images to shape ({ new_h } , { new_w } , { new_d } )."
29
+ )
30
+ else :
31
+ output_stride = self .encoder .output_stride
32
+ if h % output_stride != 0 or w % output_stride != 0 or d % output_stride != 0 :
33
+ new_h = (h // output_stride + 1 ) * output_stride if h % output_stride != 0 else h
34
+ new_w = (w // output_stride + 1 ) * output_stride if w % output_stride != 0 else w
35
+ new_d = (d // output_stride + 1 ) * output_stride if d % output_stride != 0 else d
36
+ raise RuntimeError (
37
+ f"Wrong input shape height={ h } , width={ w } , depth={ d } . Expected image height and width and depth "
38
+ f"divisible by { output_stride } . Consider pad your images to shape ({ new_h } , { new_w } , { new_d } )."
39
+ )
40
+ except :
41
+ pass
39
42
40
43
def forward (self , x ):
41
44
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
0 commit comments