Commit fad7fb6
Update UNETR to enable resize to longest side (#192)
Add resize functionality to the UNETR and use it by default
1 parent 7111f72 commit fad7fb6
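A quick usage sketch of the new behavior, based on the tests below in this commit (the default encoder and output settings are whatever UNETR ships with; the shape comment at the end is an assumption inferred from the postprocessing changes in this diff):

```python
import torch
from torch_em.model import UNETR

# resize_input=True is the new default: inputs are resized so that their longest
# side matches the encoder's expected image size, padded to a square, and the
# prediction is resized back to the original input shape.
model = UNETR()
x = torch.randn(1, 3, 512, 512)
y = model(x)  # spatial size of y should match the 512 x 512 input

# The previous pad-only behavior can be restored explicitly.
model_no_resize = UNETR(resize_input=False)
```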

File tree: 2 files changed, +53 −17 lines changed


test/model/test_unetr.py

Lines changed: 14 additions & 4 deletions

@@ -26,14 +26,24 @@ def test_unetr(self):
         from torch_em.model import UNETR
 
         model = UNETR()
-        self._test_net(model, (1, 3, 256, 256))
+        self._test_net(model, (1, 3, 512, 512))
+
+    def test_unetr_no_resize(self):
+        from torch_em.model import UNETR
+
+        model = UNETR(resize_input=False)
+        self._test_net(model, (1, 3, 512, 512))
 
     @unittest.skipIf(micro_sam is None, "Needs micro_sam")
     def test_unetr_from_sam(self):
-        from torch_em.model import build_unetr_with_sam_intialization
+        from torch_em.model import UNETR
+        from micro_sam.util import models
+
+        model_registry = models()
+        checkpoint = model_registry.fetch("vit_b")
 
-        model = build_unetr_with_sam_intialization()
-        self._test_net(model, (1, 3, 256, 256))
+        model = UNETR(encoder_checkpoint=checkpoint)
+        self._test_net(model, (1, 3, 512, 512))
 
 
 if __name__ == "__main__":
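The test no longer goes through build_unetr_with_sam_intialization; instead it fetches a SAM checkpoint via micro_sam's model registry and passes it to UNETR directly. A standalone sketch of that pattern (assumes micro_sam is installed, as the test's skip condition requires, and that the default encoder type matches the fetched "vit_b" checkpoint):

```python
from torch_em.model import UNETR
from micro_sam.util import models

# Fetch the SAM vit_b weights through micro_sam's model registry and use them
# to initialize the UNETR image encoder.
model_registry = models()
checkpoint = model_registry.fetch("vit_b")
model = UNETR(encoder_checkpoint=checkpoint)
```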

torch_em/model/unetr.py

Lines changed: 39 additions & 13 deletions

@@ -36,8 +36,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
                 )
                 encoder_state = model.image_encoder.state_dict()
             except Exception:
-                # If we have a MAE encoder, then we directly load the encoder state
-                # from the checkpoint.
+                # Try loading the encoder state directly from a checkpoint.
                 encoder_state = torch.load(checkpoint)
 
         elif backbone == "mae":
@@ -68,16 +67,18 @@ def __init__(
         out_channels: int = 1,
         use_sam_stats: bool = False,
         use_mae_stats: bool = False,
+        resize_input: bool = True,
         encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
         final_activation: Optional[Union[str, nn.Module]] = None,
         use_skip_connection: bool = True,
-        embed_dim: Optional[int] = None
+        embed_dim: Optional[int] = None,
     ) -> None:
         super().__init__()
 
         self.use_sam_stats = use_sam_stats
         self.use_mae_stats = use_mae_stats
         self.use_skip_connection = use_skip_connection
+        self.resize_input = resize_input
 
         if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
             print(f"Using {encoder} from {backbone.upper()}")
@@ -152,25 +153,49 @@ def _get_activation(self, activation):
             raise ValueError(f"Invalid activation: {activation}")
         return return_activation()
 
+    @staticmethod
+    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+        """Compute the output size given input size and target long side length.
+        """
+        scale = long_side_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
+
+    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
+        """Resizes the image so that the longest side has the correct length.
+
+        Expects batched images with shape BxCxHxW and float format.
+        """
+        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
+        return F.interpolate(
+            image, target_size, mode="bilinear", align_corners=False, antialias=True
+        )
+
     def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = x.device
 
         if self.use_sam_stats:
-            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
-            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
+            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
+            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
         elif self.use_mae_stats:
             # TODO: add mean std from mae experiments (or open up arguments for this)
             raise NotImplementedError
         else:
-            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(-1, 1, 1).to(device)
-            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(-1, 1, 1).to(device)
+            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
+            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)
+
+        if self.resize_input:
+            x = self.resize_longest_side(x)
+        input_shape = x.shape[-2:]
 
         x = (x - pixel_mean) / pixel_std
         h, w = x.shape[-2:]
         padh = self.encoder.img_size - h
         padw = self.encoder.img_size - w
         x = F.pad(x, (0, padw, 0, padh))
-        return x
+        return x, input_shape
 
     def postprocess_masks(
         self,
@@ -189,10 +214,11 @@ def postprocess_masks(
         return masks
 
     def forward(self, x):
-        org_shape = x.shape[-2:]
+        original_shape = x.shape[-2:]
 
-        # backbone used for reshaping inputs to the desired "encoder" shape
-        x = torch.stack([self.preprocess(e) for e in x], dim=0)
+        # Reshape the inputs to the shape expected by the encoder
+        # and normalize the inputs if normalization is part of the model.
+        x, input_shape = self.preprocess(x)
 
         use_skip_connection = getattr(self, "use_skip_connection", True)
 
@@ -236,7 +262,7 @@ def forward(self, x):
         if self.final_activation is not None:
             x = self.final_activation(x)
 
-        x = self.postprocess_masks(x, org_shape, org_shape)
+        x = self.postprocess_masks(x, input_shape, original_shape)
         return x
 
 
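The resize added to preprocess mirrors the longest-side preprocessing used by Segment Anything: scale the image so that its longer spatial side equals encoder.img_size, interpolate, then pad the shorter side up to a square. A plain-Python sketch of the shape computation from get_preprocess_shape above, with a worked example (the 1024 target is just an illustrative encoder image size):

```python
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
    # Scale so that max(oldh, oldw) becomes long_side_length, rounding to the nearest int.
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    return int(newh + 0.5), int(neww + 0.5)

# Example: a 768 x 512 input with an encoder image size of 1024 is resized to
# 1024 x 683 and then padded on the right to a square 1024 x 1024 before encoding.
print(get_preprocess_shape(768, 512, 1024))  # (1024, 683)
```

The switch from .view(-1, 1, 1) to .view(1, -1, 1, 1) for the normalization statistics goes with preprocess now receiving the full BxCxHxW batch (previously forward called it per sample and stacked the results), so the mean and std must broadcast over the channel dimension of a batched tensor.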