@@ -48,7 +48,6 @@ def forward(
48
48
boxes : Optional [torch .Tensor ] = None ,
49
49
mask_input : Optional [torch .Tensor ] = None ,
50
50
multimask_output : bool = True ,
51
- img_idx : int = - 1 ,
52
51
):
53
52
assert high_res_feats [0 ].size () == (self .batch_size , 32 , 256 , 256 )
54
53
assert high_res_feats [1 ].size () == (self .batch_size , 64 , 128 , 128 )
@@ -73,7 +72,6 @@ def forward(
73
72
assert boxes is None
74
73
assert mask_input is None
75
74
assert multimask_output
76
- assert img_idx == - 1
77
75
if self .predictor is None :
78
76
assert self .aoti_compiled_model is not None
79
77
return self .aoti_compiled_model (
@@ -85,7 +83,6 @@ def forward(
85
83
boxes = boxes ,
86
84
mask_input = mask_input ,
87
85
multimask_output = multimask_output ,
88
- img_idx = img_idx ,
89
86
)
90
87
return self .predictor ._predict_masks (
91
88
high_res_feats ,
@@ -96,7 +93,6 @@ def forward(
96
93
boxes = boxes ,
97
94
mask_input = mask_input ,
98
95
multimask_output = multimask_output ,
99
- img_idx = img_idx ,
100
96
)
101
97
102
98
@@ -176,10 +172,137 @@ def export_model(
176
172
overwrite = overwrite ,
177
173
)
178
174
179
- print (f"{ task_type } cannot export _predict_masks" )
180
- return
175
+ if task_type in []:
176
+ example_input_args = ()
177
+ example_input_kwargs = {
178
+ "points" : (
179
+ torch .randn (
180
+ points_per_batch ,
181
+ 1 ,
182
+ 2 ,
183
+ dtype = torch .float32 ,
184
+ device = mask_generator .predictor .device ,
185
+ ),
186
+ torch .ones (
187
+ points_per_batch ,
188
+ 1 ,
189
+ dtype = torch .int32 ,
190
+ device = mask_generator .predictor .device ,
191
+ ),
192
+ ),
193
+ "boxes" : None ,
194
+ "masks" : None ,
195
+ }
196
+ aot_compile (
197
+ model_directory ,
198
+ "sam2_sam_prompt_encoder" ,
199
+ mask_generator .predictor .model .sam_prompt_encoder ,
200
+ example_input_args ,
201
+ sample_kwargs = example_input_kwargs ,
202
+ overwrite = overwrite ,
203
+ )
204
+
205
+ if task_type in []:
206
+ example_input_args = ()
207
+ example_input_kwargs = {
208
+ "image_embeddings" : torch .randn (
209
+ batch_size ,
210
+ 256 ,
211
+ 64 ,
212
+ 64 ,
213
+ dtype = torch .float32 ,
214
+ device = mask_generator .predictor .device ,
215
+ ),
216
+ "image_pe" : torch .randn (
217
+ batch_size ,
218
+ 256 ,
219
+ 64 ,
220
+ 64 ,
221
+ dtype = torch .float32 ,
222
+ device = mask_generator .predictor .device ,
223
+ ),
224
+ "sparse_prompt_embeddings" : torch .randn (
225
+ batch_size ,
226
+ 2 ,
227
+ 256 ,
228
+ dtype = torch .float32 ,
229
+ device = mask_generator .predictor .device ,
230
+ ),
231
+ "dense_prompt_embeddings" : torch .randn (
232
+ batch_size ,
233
+ 256 ,
234
+ 64 ,
235
+ 64 ,
236
+ dtype = torch .float32 ,
237
+ device = mask_generator .predictor .device ,
238
+ ),
239
+ "multimask_output" : True ,
240
+ "repeat_image" : False ,
241
+ "high_res_features" : [
242
+ torch .randn (
243
+ batch_size ,
244
+ 32 ,
245
+ 256 ,
246
+ 256 ,
247
+ dtype = mask_generator .predictor ._image_dtype ,
248
+ device = mask_generator .predictor .device ,
249
+ ),
250
+ torch .randn (
251
+ batch_size ,
252
+ 64 ,
253
+ 128 ,
254
+ 128 ,
255
+ dtype = mask_generator .predictor ._image_dtype ,
256
+ device = mask_generator .predictor .device ,
257
+ ),
258
+ ],
259
+ }
260
+ aot_compile (
261
+ model_directory ,
262
+ "sam2_sam_mask_decoder" ,
263
+ mask_generator .predictor .model .sam_mask_decoder ,
264
+ example_input_args ,
265
+ sample_kwargs = example_input_kwargs ,
266
+ overwrite = overwrite ,
267
+ )
268
+
269
+ if task_type in []:
270
+ example_input_args = (
271
+ torch .randn (
272
+ points_per_batch ,
273
+ 256 ,
274
+ 64 ,
275
+ 64 ,
276
+ dtype = mask_generator .predictor .model .sam_mask_decoder ._src_dtype ,
277
+ device = mask_generator .predictor .device ,
278
+ ),
279
+ torch .randn (
280
+ points_per_batch ,
281
+ 256 ,
282
+ 64 ,
283
+ 64 ,
284
+ dtype = mask_generator .predictor .model .sam_mask_decoder ._src_dtype ,
285
+ device = mask_generator .predictor .device ,
286
+ ),
287
+ torch .randn (
288
+ points_per_batch ,
289
+ 8 ,
290
+ 256 ,
291
+ dtype = mask_generator .predictor .model .sam_mask_decoder ._src_dtype ,
292
+ device = mask_generator .predictor .device ,
293
+ ),
294
+ )
295
+ example_input_kwargs = {}
296
+ aot_compile (
297
+ model_directory ,
298
+ "sam2_sam_mask_decoder_transformer" ,
299
+ mask_generator .predictor .model .sam_mask_decoder .transformer ,
300
+ example_input_args ,
301
+ sample_kwargs = example_input_kwargs ,
302
+ overwrite = overwrite ,
303
+ )
181
304
182
- if task_type in ["sps" ]:
305
+ if task_type in ["amg" , " sps" ]:
183
306
example_input_high_res_feats = [
184
307
torch .randn (
185
308
batch_size ,
@@ -239,7 +362,6 @@ def export_model(
239
362
"boxes" : None ,
240
363
"mask_input" : None ,
241
364
"multimask_output" : True ,
242
- "img_idx" : - 1 ,
243
365
}
244
366
245
367
sam2_image_predict_masks = SAM2ImagePredictor_predict_masks (
@@ -301,30 +423,54 @@ def load_exported_model(
301
423
pkg_m = LoadedModel (pkg )
302
424
mask_generator .predictor .model .image_encoder = pkg_m
303
425
304
- print (f"End load image encoder. Took { time .time () - t0 } s" )
305
- return mask_generator
306
-
307
- if task_type in ["amg" , "mps" ]:
426
+ if task_type in ["mps" ]:
308
427
return mask_generator
309
428
310
- path = Path (model_directory ) / Path ("sam2_image_predict_masks.pt2" )
311
- assert path .exists (), f"Expected { path } to exist"
312
- print (f"Start load from { path } " )
313
- pkg = torch ._inductor .aoti_load_package (str (path ))
314
- if task_type == "amg" :
315
- assert points_per_batch > 1
316
- if task_type == "sps" :
317
- assert points_per_batch == 1
318
- if task_type == "mps" :
319
- assert points_per_batch is None
320
- pkg_m = SAM2ImagePredictor_predict_masks (
321
- None ,
322
- batch_size = batch_size ,
323
- points_per_batch = points_per_batch ,
324
- aoti_compiled_model = pkg ,
325
- furious = furious ,
326
- )
327
- mask_generator .predictor ._predict_masks = pkg_m .forward
429
+ if task_type in []:
430
+ path = Path (model_directory ) / Path ("sam2_sam_prompt_encoder.pt2" )
431
+ assert path .exists (), f"Expected { path } to exist"
432
+ print (f"Start load from { path } " )
433
+ pkg = torch ._inductor .aoti_load_package (str (path ))
434
+ pkg_m = LoadedModel (pkg )
435
+ mask_generator .predictor .model .sam_prompt_encoder .forward = pkg_m .forward
436
+
437
+ if task_type in []:
438
+ path = Path (model_directory ) / Path ("sam2_sam_mask_decoder.pt2" )
439
+ assert path .exists (), f"Expected { path } to exist"
440
+ print (f"Start load from { path } " )
441
+ pkg = torch ._inductor .aoti_load_package (str (path ))
442
+ pkg_m = LoadedModel (pkg )
443
+ mask_generator .predictor .model .sam_mask_decoder .forward = pkg_m .forward
444
+
445
+ if task_type in []:
446
+ path = Path (model_directory ) / Path ("sam2_sam_mask_decoder_transformer.pt2" )
447
+ assert path .exists (), f"Expected { path } to exist"
448
+ print (f"Start load from { path } " )
449
+ pkg = torch ._inductor .aoti_load_package (str (path ))
450
+ pkg_m = LoadedModel (pkg )
451
+ mask_generator .predictor .model .sam_mask_decoder .transformer .forward = (
452
+ pkg_m .forward
453
+ )
454
+
455
+ if task_type in ["amg" , "sps" ]:
456
+ path = Path (model_directory ) / Path ("sam2_image_predict_masks.pt2" )
457
+ assert path .exists (), f"Expected { path } to exist"
458
+ print (f"Start load from { path } " )
459
+ pkg = torch ._inductor .aoti_load_package (str (path ))
460
+ if task_type == "amg" :
461
+ assert points_per_batch > 1
462
+ if task_type == "sps" :
463
+ assert points_per_batch == 1
464
+ if task_type == "mps" :
465
+ assert points_per_batch is None
466
+ pkg_m = SAM2ImagePredictor_predict_masks (
467
+ None ,
468
+ batch_size = batch_size ,
469
+ points_per_batch = points_per_batch ,
470
+ aoti_compiled_model = pkg ,
471
+ furious = furious ,
472
+ )
473
+ mask_generator .predictor ._predict_masks = pkg_m .forward
328
474
329
475
print (f"End load image encoder and predict masks. Took { time .time () - t0 } s" )
330
476
@@ -352,12 +498,13 @@ def set_fast(
352
498
dynamic = False ,
353
499
)
354
500
elif task_type == "amg" :
355
- mask_generator .predictor ._predict_masks = torch .compile (
356
- mask_generator .predictor ._predict_masks ,
357
- mode = "max-autotune" ,
358
- fullgraph = True ,
359
- dynamic = False ,
360
- )
501
+ if not loaded_exported_model :
502
+ mask_generator .predictor ._predict_masks = torch .compile (
503
+ mask_generator .predictor ._predict_masks ,
504
+ mode = "max-autotune" ,
505
+ fullgraph = True ,
506
+ dynamic = False ,
507
+ )
361
508
else :
362
509
# TODO: This might need to be under "allow_recompiles"
363
510
# mps encounters rapidly changing points per batch
0 commit comments