
Commit 8d14f0e

SAM2: more export, small perf improvements (#1673)
1 parent 8afd10e commit 8d14f0e

File tree

8 files changed, +326 -131 lines changed


examples/sam2_amg_server/compile_export_utils.py

Lines changed: 183 additions & 36 deletions
@@ -48,7 +48,6 @@ def forward(
         boxes: Optional[torch.Tensor] = None,
         mask_input: Optional[torch.Tensor] = None,
         multimask_output: bool = True,
-        img_idx: int = -1,
     ):
         assert high_res_feats[0].size() == (self.batch_size, 32, 256, 256)
         assert high_res_feats[1].size() == (self.batch_size, 64, 128, 128)
@@ -73,7 +72,6 @@ def forward(
         assert boxes is None
         assert mask_input is None
         assert multimask_output
-        assert img_idx == -1
         if self.predictor is None:
             assert self.aoti_compiled_model is not None
             return self.aoti_compiled_model(
@@ -85,7 +83,6 @@ def forward(
                 boxes=boxes,
                 mask_input=mask_input,
                 multimask_output=multimask_output,
-                img_idx=img_idx,
             )
         return self.predictor._predict_masks(
             high_res_feats,
@@ -96,7 +93,6 @@ def forward(
             boxes=boxes,
             mask_input=mask_input,
             multimask_output=multimask_output,
-            img_idx=img_idx,
         )
 
 
@@ -176,10 +172,137 @@ def export_model(
         overwrite=overwrite,
     )
 
-    print(f"{task_type} cannot export _predict_masks")
-    return
+    if task_type in []:
+        example_input_args = ()
+        example_input_kwargs = {
+            "points": (
+                torch.randn(
+                    points_per_batch,
+                    1,
+                    2,
+                    dtype=torch.float32,
+                    device=mask_generator.predictor.device,
+                ),
+                torch.ones(
+                    points_per_batch,
+                    1,
+                    dtype=torch.int32,
+                    device=mask_generator.predictor.device,
+                ),
+            ),
+            "boxes": None,
+            "masks": None,
+        }
+        aot_compile(
+            model_directory,
+            "sam2_sam_prompt_encoder",
+            mask_generator.predictor.model.sam_prompt_encoder,
+            example_input_args,
+            sample_kwargs=example_input_kwargs,
+            overwrite=overwrite,
+        )
+
+    if task_type in []:
+        example_input_args = ()
+        example_input_kwargs = {
+            "image_embeddings": torch.randn(
+                batch_size,
+                256,
+                64,
+                64,
+                dtype=torch.float32,
+                device=mask_generator.predictor.device,
+            ),
+            "image_pe": torch.randn(
+                batch_size,
+                256,
+                64,
+                64,
+                dtype=torch.float32,
+                device=mask_generator.predictor.device,
+            ),
+            "sparse_prompt_embeddings": torch.randn(
+                batch_size,
+                2,
+                256,
+                dtype=torch.float32,
+                device=mask_generator.predictor.device,
+            ),
+            "dense_prompt_embeddings": torch.randn(
+                batch_size,
+                256,
+                64,
+                64,
+                dtype=torch.float32,
+                device=mask_generator.predictor.device,
+            ),
+            "multimask_output": True,
+            "repeat_image": False,
+            "high_res_features": [
+                torch.randn(
+                    batch_size,
+                    32,
+                    256,
+                    256,
+                    dtype=mask_generator.predictor._image_dtype,
+                    device=mask_generator.predictor.device,
+                ),
+                torch.randn(
+                    batch_size,
+                    64,
+                    128,
+                    128,
+                    dtype=mask_generator.predictor._image_dtype,
+                    device=mask_generator.predictor.device,
+                ),
+            ],
+        }
+        aot_compile(
+            model_directory,
+            "sam2_sam_mask_decoder",
+            mask_generator.predictor.model.sam_mask_decoder,
+            example_input_args,
+            sample_kwargs=example_input_kwargs,
+            overwrite=overwrite,
+        )
+
+    if task_type in []:
+        example_input_args = (
+            torch.randn(
+                points_per_batch,
+                256,
+                64,
+                64,
+                dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype,
+                device=mask_generator.predictor.device,
+            ),
+            torch.randn(
+                points_per_batch,
+                256,
+                64,
+                64,
+                dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype,
+                device=mask_generator.predictor.device,
+            ),
+            torch.randn(
+                points_per_batch,
+                8,
+                256,
+                dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype,
+                device=mask_generator.predictor.device,
+            ),
+        )
+        example_input_kwargs = {}
+        aot_compile(
+            model_directory,
+            "sam2_sam_mask_decoder_transformer",
+            mask_generator.predictor.model.sam_mask_decoder.transformer,
+            example_input_args,
+            sample_kwargs=example_input_kwargs,
+            overwrite=overwrite,
+        )
 
-    if task_type in ["sps"]:
+    if task_type in ["amg", "sps"]:
         example_input_high_res_feats = [
             torch.randn(
                 batch_size,
@@ -239,7 +362,6 @@ def export_model(
             "boxes": None,
             "mask_input": None,
             "multimask_output": True,
-            "img_idx": -1,
         }
 
         sam2_image_predict_masks = SAM2ImagePredictor_predict_masks(
@@ -301,30 +423,54 @@ def load_exported_model(
         pkg_m = LoadedModel(pkg)
         mask_generator.predictor.model.image_encoder = pkg_m
 
-    print(f"End load image encoder. Took {time.time() - t0}s")
-    return mask_generator
-
-    if task_type in ["amg", "mps"]:
+    if task_type in ["mps"]:
         return mask_generator
 
-    path = Path(model_directory) / Path("sam2_image_predict_masks.pt2")
-    assert path.exists(), f"Expected {path} to exist"
-    print(f"Start load from {path}")
-    pkg = torch._inductor.aoti_load_package(str(path))
-    if task_type == "amg":
-        assert points_per_batch > 1
-    if task_type == "sps":
-        assert points_per_batch == 1
-    if task_type == "mps":
-        assert points_per_batch is None
-    pkg_m = SAM2ImagePredictor_predict_masks(
-        None,
-        batch_size=batch_size,
-        points_per_batch=points_per_batch,
-        aoti_compiled_model=pkg,
-        furious=furious,
-    )
-    mask_generator.predictor._predict_masks = pkg_m.forward
+    if task_type in []:
+        path = Path(model_directory) / Path("sam2_sam_prompt_encoder.pt2")
+        assert path.exists(), f"Expected {path} to exist"
+        print(f"Start load from {path}")
+        pkg = torch._inductor.aoti_load_package(str(path))
+        pkg_m = LoadedModel(pkg)
+        mask_generator.predictor.model.sam_prompt_encoder.forward = pkg_m.forward
+
+    if task_type in []:
+        path = Path(model_directory) / Path("sam2_sam_mask_decoder.pt2")
+        assert path.exists(), f"Expected {path} to exist"
+        print(f"Start load from {path}")
+        pkg = torch._inductor.aoti_load_package(str(path))
+        pkg_m = LoadedModel(pkg)
+        mask_generator.predictor.model.sam_mask_decoder.forward = pkg_m.forward
+
+    if task_type in []:
+        path = Path(model_directory) / Path("sam2_sam_mask_decoder_transformer.pt2")
+        assert path.exists(), f"Expected {path} to exist"
+        print(f"Start load from {path}")
+        pkg = torch._inductor.aoti_load_package(str(path))
+        pkg_m = LoadedModel(pkg)
+        mask_generator.predictor.model.sam_mask_decoder.transformer.forward = (
+            pkg_m.forward
+        )
+
+    if task_type in ["amg", "sps"]:
+        path = Path(model_directory) / Path("sam2_image_predict_masks.pt2")
+        assert path.exists(), f"Expected {path} to exist"
+        print(f"Start load from {path}")
+        pkg = torch._inductor.aoti_load_package(str(path))
+        if task_type == "amg":
+            assert points_per_batch > 1
+        if task_type == "sps":
+            assert points_per_batch == 1
+        if task_type == "mps":
+            assert points_per_batch is None
+        pkg_m = SAM2ImagePredictor_predict_masks(
+            None,
+            batch_size=batch_size,
+            points_per_batch=points_per_batch,
+            aoti_compiled_model=pkg,
+            furious=furious,
+        )
+        mask_generator.predictor._predict_masks = pkg_m.forward
 
     print(f"End load image encoder and predict masks. Took {time.time() - t0}s")
 
@@ -352,12 +498,13 @@ def set_fast(
             dynamic=False,
         )
     elif task_type == "amg":
-        mask_generator.predictor._predict_masks = torch.compile(
-            mask_generator.predictor._predict_masks,
-            mode="max-autotune",
-            fullgraph=True,
-            dynamic=False,
-        )
+        if not loaded_exported_model:
+            mask_generator.predictor._predict_masks = torch.compile(
+                mask_generator.predictor._predict_masks,
+                mode="max-autotune",
+                fullgraph=True,
+                dynamic=False,
+            )
     else:
         # TODO: This might need to be under "allow_recompiles"
         # mps encounters rapidly changing points per batch
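
Note that the three `if task_type in []:` guards added above can never fire; they read as deliberately disabled toggles for exporting the SAM2 prompt encoder, mask decoder, and decoder transformer individually. The `aot_compile` helper they call is defined earlier in this file and does not appear in the diff. As a rough sketch, assuming it wraps `torch.export` plus AOTInductor packaging (the module, shapes, and package path below are illustrative, not from this commit), the export/load round trip looks like:

import torch

# Hypothetical stand-in for one of the SAM2 sub-modules exported above.
class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0

model = ToyModule().eval()
sample_args = (torch.randn(8, 16),)

# Export: trace the module into an ExportedProgram, then compile and
# package it as a self-contained .pt2 artifact on disk.
ep = torch.export.export(model, sample_args)
package_path = torch._inductor.aoti_compile_and_package(
    ep, package_path="toy_module.pt2"
)

# Load: aoti_load_package returns a callable that runs the compiled
# kernels; load_exported_model above wraps it in LoadedModel and swaps
# it in for the sub-module's .forward.
compiled = torch._inductor.aoti_load_package(package_path)
assert torch.allclose(compiled(*sample_args), model(*sample_args))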

examples/sam2_amg_server/generate_data.py

Lines changed: 45 additions & 9 deletions
@@ -21,6 +21,38 @@
 from tqdm import tqdm
 
 
+def profiler_runner(path, fn, *args, **kwargs):
+    with torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        record_shapes=True,
+    ) as prof:
+        result = fn(*args, **kwargs)
+    prof.export_chrome_trace(path)
+    return result
+
+
+def memory_runner(path, fn, *args, **kwargs):
+    print("Start memory recording")
+    torch.cuda.synchronize()
+    torch.cuda.memory._record_memory_history(
+        True, trace_alloc_max_entries=100000, trace_alloc_record_context=True
+    )
+    result = fn(*args, **kwargs)
+    torch.cuda.synchronize()
+    snapshot = torch.cuda.memory._snapshot()
+    print("Finish memory recording")
+    import pickle
+
+    with open(path, "wb") as f:
+        pickle.dump(snapshot, f)
+    # Use to convert pickle file into html
+    # python torch/cuda/_memory_viz.py trace_plot <snapshot>.pickle -o <snapshot>.html
+    return result
+
+
 def latencies_statistics(data):
     # Convert the list to a NumPy array
     data_array = np.array(data)
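
Both new helpers wrap an arbitrary callable, forwarding its arguments and return value, so either can be dropped around an existing entry point. A hypothetical usage sketch (`generate_masks` and the output paths are placeholders, not part of this commit):

import torch

def generate_masks(n: int) -> torch.Tensor:
    # Placeholder CUDA workload standing in for real mask generation.
    x = torch.randn(n, 1024, device="cuda")
    return torch.nn.functional.relu(x @ x.T)

# Chrome trace: open in chrome://tracing or https://ui.perfetto.dev
result = profiler_runner("trace.json.gz", generate_masks, 256)

# Memory snapshot: convert the pickle to HTML as the comment above notes:
# python torch/cuda/_memory_viz.py trace_plot asdf.pickle -o asdf.html
result = memory_runner("asdf.pickle", generate_masks, 256)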
@@ -330,16 +362,17 @@ def decode_img_bytes(img_bytes_tensors, gpu_preproc, baseline):
     for img_bytes_tensor in img_bytes_tensors:
         with record_function("decode image bytes"):
             if gpu_preproc:
-                # NOTE: We have to use numpy for the baseline
-                assert not baseline
-                from torchvision import io as tio
-
-                image_tensor = tio.decode_jpeg(
-                    img_bytes_tensor, device="cuda", mode=tio.ImageReadMode.RGB
-                )
-                from torchvision.transforms.v2 import functional as F
+                image_tensor = file_bytes_to_image_tensor(img_bytes_tensor)
+                from torchvision.transforms import ToTensor, v2
 
-                image_tensor = F.to_dtype(image_tensor, torch.float32, scale=True)
+                if not baseline:
+                    image_tensor = torch.from_numpy(image_tensor)
+                    image_tensor = image_tensor.permute((2, 0, 1))
+                    image_tensor = image_tensor.cuda()
+                    with record_function("v2.ToDtype"):
+                        image_tensor = v2.ToDtype(torch.float32, scale=True)(
+                            image_tensor
+                        )
             else:
                 image_tensor = file_bytes_to_image_tensor(img_bytes_tensor)
                 from torchvision.transforms import ToTensor
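
The reworked `gpu_preproc` branch now decodes with the same `file_bytes_to_image_tensor` helper as the CPU path (which, judging by the `torch.from_numpy` call, returns an HWC uint8 numpy array) and only moves the tensor to the GPU for the float conversion. A standalone sketch of that layout and dtype dance, with a synthetic array standing in for the decoded image:

import numpy as np
import torch
from torchvision.transforms import v2

# Synthetic HWC uint8 image, the shape file_bytes_to_image_tensor appears
# to return (the real helper is defined elsewhere in this repo).
image_np = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)

image_tensor = torch.from_numpy(image_np)  # HWC uint8, on CPU
image_tensor = image_tensor.permute((2, 0, 1))  # CHW, as torchvision expects
image_tensor = image_tensor.cuda()  # move to GPU before the float conversion
# scale=True rescales uint8 [0, 255] to float32 [0.0, 1.0]
image_tensor = v2.ToDtype(torch.float32, scale=True)(image_tensor)
assert image_tensor.shape == (3, 1024, 1024) and image_tensor.max() <= 1.0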
@@ -431,6 +464,7 @@ def main(
     quiet=False,
     gpu_preproc=False,
     batch_size=1,
+    seed=42,
 ):
     if batch_size <= 0:
         raise ValueError("Expected --batch_size to be at least 1 but got {batch_size}")
@@ -502,6 +536,7 @@ def main(
         from torchao._models.sam2.utils.amg import (
             mask_to_rle_pytorch_2 as mask_to_rle_pytorch,
         )
+    torch.manual_seed(seed)
     device = "cuda"
     sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
     if verbose:
@@ -628,4 +663,5 @@ def main(
 main.__doc__ = main_docstring()
 if __name__ == "__main__":
     # profiler_runner("asdf.json.gz", fire.Fire, main)
+    # memory_runner("asdf.pickle", fire.Fire, main)
     fire.Fire(main)

examples/sam2_amg_server/reproduce_experiments.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def run(task, output_path: Path, kwargs, baseline_folder=None, environ=None):
     stdout, stderr = run_script_with_args(
         [
             "generate_data.py",
-            "~/checkpoints/sam2",
+            f"{str(Path.home())}/checkpoints/sam2",
             "large",
             task,
             image_paths,
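
The one-line fix replaces a literal `~` with the expanded home directory: tilde expansion is a shell feature, so a `~/...` argument handed to a subprocess without a shell reaches `generate_data.py` unexpanded and the checkpoint path never resolves. A quick illustration (paths are illustrative):

import os.path
from pathlib import Path

# No shell means no tilde expansion; the literal string fails to resolve.
print(os.path.exists("~/checkpoints"))  # False even if ~/checkpoints exists
# The commit's fix, and an equivalent stdlib spelling:
print(f"{str(Path.home())}/checkpoints/sam2")  # e.g. /home/user/checkpoints/sam2
print(os.path.expanduser("~/checkpoints/sam2"))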
