@@ -36,8 +36,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
36
36
)
37
37
encoder_state = model .image_encoder .state_dict ()
38
38
except Exception :
39
- # If we have a MAE encoder, then we directly load the encoder state
40
- # from the checkpoint.
39
+ # Try loading the encoder state directly from a checkpoint.
41
40
encoder_state = torch .load (checkpoint )
42
41
43
42
elif backbone == "mae" :
@@ -68,16 +67,18 @@ def __init__(
68
67
out_channels : int = 1 ,
69
68
use_sam_stats : bool = False ,
70
69
use_mae_stats : bool = False ,
70
+ resize_input : bool = True ,
71
71
encoder_checkpoint : Optional [Union [str , OrderedDict ]] = None ,
72
72
final_activation : Optional [Union [str , nn .Module ]] = None ,
73
73
use_skip_connection : bool = True ,
74
- embed_dim : Optional [int ] = None
74
+ embed_dim : Optional [int ] = None ,
75
75
) -> None :
76
76
super ().__init__ ()
77
77
78
78
self .use_sam_stats = use_sam_stats
79
79
self .use_mae_stats = use_mae_stats
80
80
self .use_skip_connection = use_skip_connection
81
+ self .resize_input = resize_input
81
82
82
83
if isinstance (encoder , str ): # "vit_b" / "vit_l" / "vit_h"
83
84
print (f"Using { encoder } from { backbone .upper ()} " )
@@ -152,25 +153,49 @@ def _get_activation(self, activation):
152
153
raise ValueError (f"Invalid activation: { activation } " )
153
154
return return_activation ()
154
155
156
+ @staticmethod
157
+ def get_preprocess_shape (oldh : int , oldw : int , long_side_length : int ) -> Tuple [int , int ]:
158
+ """Compute the output size given input size and target long side length.
159
+ """
160
+ scale = long_side_length * 1.0 / max (oldh , oldw )
161
+ newh , neww = oldh * scale , oldw * scale
162
+ neww = int (neww + 0.5 )
163
+ newh = int (newh + 0.5 )
164
+ return (newh , neww )
165
+
166
def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
    """Resizes the image so that the longest side has the correct length.

    Expects batched images with shape BxCxHxW and float format.
    """
    height, width = image.shape[2], image.shape[3]
    new_size = self.get_preprocess_shape(height, width, self.encoder.img_size)
    resized = F.interpolate(
        image, new_size, mode="bilinear", align_corners=False, antialias=True
    )
    return resized
175
+
155
176
def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Normalize, optionally resize, and pad the input to the encoder's square size.

    Args:
        x: Batched input images of shape BxCxHxW in float format.

    Returns:
        The preprocessed batch, padded bottom/right to
        (encoder.img_size, encoder.img_size), and the spatial shape (H, W)
        of the input after resizing but before padding — callers need it
        to undo the padding on the output masks.

    Raises:
        NotImplementedError: If ``use_mae_stats`` is set (MAE statistics TBD).
    """
    # NOTE: the return annotation previously said ``torch.Tensor`` although the
    # body returns a tuple; fixed to match the actual return value.
    device = x.device

    if self.use_sam_stats:
        # Normalization statistics used by Segment Anything (0-255 inputs).
        pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
        pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
    elif self.use_mae_stats:
        # TODO: add mean std from mae experiments (or open up arguments for this)
        raise NotImplementedError
    else:
        # Identity normalization: leave the inputs unchanged.
        pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
        pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)

    if self.resize_input:
        x = self.resize_longest_side(x)
    # Shape after resizing, before padding; returned so the padding can be undone.
    input_shape = x.shape[-2:]

    x = (x - pixel_mean) / pixel_std
    h, w = x.shape[-2:]
    # Pad bottom/right so the output is exactly (img_size, img_size).
    padh = self.encoder.img_size - h
    padw = self.encoder.img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x, input_shape
174
199
175
200
def postprocess_masks (
176
201
self ,
@@ -189,10 +214,11 @@ def postprocess_masks(
189
214
return masks
190
215
191
216
def forward (self , x ):
192
- org_shape = x .shape [- 2 :]
217
+ original_shape = x .shape [- 2 :]
193
218
194
- # backbone used for reshaping inputs to the desired "encoder" shape
195
- x = torch .stack ([self .preprocess (e ) for e in x ], dim = 0 )
219
+ # Reshape the inputs to the shape expected by the encoder
220
+ # and normalize the inputs if normalization is part of the model.
221
+ x , input_shape = self .preprocess (x )
196
222
197
223
use_skip_connection = getattr (self , "use_skip_connection" , True )
198
224
@@ -236,7 +262,7 @@ def forward(self, x):
236
262
if self .final_activation is not None :
237
263
x = self .final_activation (x )
238
264
239
- x = self .postprocess_masks (x , org_shape , org_shape )
265
+ x = self .postprocess_masks (x , input_shape , original_shape )
240
266
return x
241
267
242
268
0 commit comments