Skip to content

Commit 0065b98

Browse files
committedAug 1, 2024
Support for timm_3d models
1 parent be40e83 commit 0065b98

File tree

1 file changed

+27
-24
lines changed
  • segmentation_models_pytorch_3d/base

1 file changed

+27
-24
lines changed
 

‎segmentation_models_pytorch_3d/base/model.py

+27-24
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,33 @@ def initialize(self):
1212
def check_input_shape(self, x):
1313

1414
h, w, d = x.shape[-3:]
15-
if self.encoder.strides is not None:
16-
hs, ws, ds = 1, 1, 1
17-
for stride in self.encoder.strides:
18-
hs *= stride[0]
19-
ws *= stride[1]
20-
ds *= stride[2]
21-
if h % hs != 0 or w % ws != 0 or d % ds != 0:
22-
new_h = (h // hs + 1) * hs if h % hs != 0 else h
23-
new_w = (w // ws + 1) * ws if w % ws != 0 else w
24-
new_d = (d // ds + 1) * ds if d % ds != 0 else d
25-
raise RuntimeError(
26-
f"Wrong input shape height={h}, width={w}, depth={d}. Expected image height and width and depth "
27-
f"divisible by {hs}, {ws}, {ds}. Consider pad your images to shape ({new_h}, {new_w}, {new_d})."
28-
)
29-
else:
30-
output_stride = self.encoder.output_stride
31-
if h % output_stride != 0 or w % output_stride != 0 or d % output_stride != 0:
32-
new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h
33-
new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w
34-
new_d = (d // output_stride + 1) * output_stride if d % output_stride != 0 else d
35-
raise RuntimeError(
36-
f"Wrong input shape height={h}, width={w}, depth={d}. Expected image height and width and depth "
37-
f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w}, {new_d})."
38-
)
15+
try:
16+
if self.encoder.strides is not None:
17+
hs, ws, ds = 1, 1, 1
18+
for stride in self.encoder.strides:
19+
hs *= stride[0]
20+
ws *= stride[1]
21+
ds *= stride[2]
22+
if h % hs != 0 or w % ws != 0 or d % ds != 0:
23+
new_h = (h // hs + 1) * hs if h % hs != 0 else h
24+
new_w = (w // ws + 1) * ws if w % ws != 0 else w
25+
new_d = (d // ds + 1) * ds if d % ds != 0 else d
26+
raise RuntimeError(
27+
f"Wrong input shape height={h}, width={w}, depth={d}. Expected image height and width and depth "
28+
f"divisible by {hs}, {ws}, {ds}. Consider pad your images to shape ({new_h}, {new_w}, {new_d})."
29+
)
30+
else:
31+
output_stride = self.encoder.output_stride
32+
if h % output_stride != 0 or w % output_stride != 0 or d % output_stride != 0:
33+
new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h
34+
new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w
35+
new_d = (d // output_stride + 1) * output_stride if d % output_stride != 0 else d
36+
raise RuntimeError(
37+
f"Wrong input shape height={h}, width={w}, depth={d}. Expected image height and width and depth "
38+
f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w}, {new_d})."
39+
)
40+
except:
41+
pass
3942

4043
def forward(self, x):
4144
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

0 commit comments

Comments
 (0)
Please sign in to comment.