Merge remote-tracking branch 'origin/main' into tensor_subclass
jainapurva committed Feb 6, 2025
2 parents 8b6da86 + 867a91f commit 9eb890c
Showing 17 changed files with 591 additions and 367 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_wheels_linux.yml
@@ -70,7 +70,7 @@ jobs:
password: ${{ secrets.TORCHAO_NOTIFY_PASSWORD }}
from: [email protected]
to: ${{ secrets.TORCHAO_NOTIFY_RECIPIENT }}
subject: breakbutterflyScheduled Build Failure for TorchAO
subject: Scheduled Build Failure for TorchAO
body: |
Build Failure Notification for TorchAO
3 changes: 2 additions & 1 deletion .github/workflows/torchao_experimental_test.yml
@@ -35,8 +35,9 @@ jobs:
conda activate venv
pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104"
pip install numpy
pip install pytest
USE_CPP=1 pip install .
- name: Run tests
run: |
conda activate venv
python torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py
pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
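To reproduce this CI step outside the workflow, a minimal sketch, assuming a checkout of the repository root with torch, numpy, and pytest installed as in the steps above:

# Local equivalent of the updated test step; assumes you run it from the
# repository root with the same dependencies the workflow installs.
import subprocess
import sys

subprocess.run(
    [
        sys.executable,
        "torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py",
    ],
    check=True,
)
subprocess.run(
    [
        sys.executable,
        "-m",
        "pytest",
        "torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py",
    ],
    check=True,
)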
219 changes: 183 additions & 36 deletions examples/sam2_amg_server/compile_export_utils.py
@@ -48,7 +48,6 @@ def forward(
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
img_idx: int = -1,
):
assert high_res_feats[0].size() == (self.batch_size, 32, 256, 256)
assert high_res_feats[1].size() == (self.batch_size, 64, 128, 128)
@@ -73,7 +72,6 @@ def forward(
assert boxes is None
assert mask_input is None
assert multimask_output
assert img_idx == -1
if self.predictor is None:
assert self.aoti_compiled_model is not None
return self.aoti_compiled_model(
@@ -85,7 +83,6 @@ def forward(
boxes=boxes,
mask_input=mask_input,
multimask_output=multimask_output,
img_idx=img_idx,
)
return self.predictor._predict_masks(
high_res_feats,
@@ -96,7 +93,6 @@ def forward(
boxes=boxes,
mask_input=mask_input,
multimask_output=multimask_output,
img_idx=img_idx,
)


@@ -176,10 +172,137 @@ def export_model(
overwrite=overwrite,
)

print(f"{task_type} cannot export _predict_masks")
return
if task_type in []:
example_input_args = ()
example_input_kwargs = {
"points": (
torch.randn(
points_per_batch,
1,
2,
dtype=torch.float32,
device=mask_generator.predictor.device,
),
torch.ones(
points_per_batch,
1,
dtype=torch.int32,
device=mask_generator.predictor.device,
),
),
"boxes": None,
"masks": None,
}
aot_compile(
model_directory,
"sam2_sam_prompt_encoder",
mask_generator.predictor.model.sam_prompt_encoder,
example_input_args,
sample_kwargs=example_input_kwargs,
overwrite=overwrite,
)

if task_type in []:
example_input_args = ()
example_input_kwargs = {
"image_embeddings": torch.randn(
batch_size,
256,
64,
64,
dtype=torch.float32,
device=mask_generator.predictor.device,
),
"image_pe": torch.randn(
batch_size,
256,
64,
64,
dtype=torch.float32,
device=mask_generator.predictor.device,
),
"sparse_prompt_embeddings": torch.randn(
batch_size,
2,
256,
dtype=torch.float32,
device=mask_generator.predictor.device,
),
"dense_prompt_embeddings": torch.randn(
batch_size,
256,
64,
64,
dtype=torch.float32,
device=mask_generator.predictor.device,
),
"multimask_output": True,
"repeat_image": False,
"high_res_features": [
torch.randn(
batch_size,
32,
256,
256,
dtype=mask_generator.predictor._image_dtype,
device=mask_generator.predictor.device,
),
torch.randn(
batch_size,
64,
128,
128,
dtype=mask_generator.predictor._image_dtype,
device=mask_generator.predictor.device,
),
],
}
aot_compile(
model_directory,
"sam2_sam_mask_decoder",
mask_generator.predictor.model.sam_mask_decoder,
example_input_args,
sample_kwargs=example_input_kwargs,
overwrite=overwrite,
)

if task_type in []:
example_input_args = (
torch.randn(
points_per_batch,
256,
64,
64,
dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype,
device=mask_generator.predictor.device,
),
torch.randn(
points_per_batch,
256,
64,
64,
dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype,
device=mask_generator.predictor.device,
),
torch.randn(
points_per_batch,
8,
256,
dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype,
device=mask_generator.predictor.device,
),
)
example_input_kwargs = {}
aot_compile(
model_directory,
"sam2_sam_mask_decoder_transformer",
mask_generator.predictor.model.sam_mask_decoder.transformer,
example_input_args,
sample_kwargs=example_input_kwargs,
overwrite=overwrite,
)

if task_type in ["sps"]:
if task_type in ["amg", "sps"]:
example_input_high_res_feats = [
torch.randn(
batch_size,
@@ -239,7 +362,6 @@ def export_model(
"boxes": None,
"mask_input": None,
"multimask_output": True,
"img_idx": -1,
}

sam2_image_predict_masks = SAM2ImagePredictor_predict_masks(
@@ -301,30 +423,54 @@ def load_exported_model(
pkg_m = LoadedModel(pkg)
mask_generator.predictor.model.image_encoder = pkg_m

print(f"End load image encoder. Took {time.time() - t0}s")
return mask_generator

if task_type in ["amg", "mps"]:
if task_type in ["mps"]:
return mask_generator

path = Path(model_directory) / Path("sam2_image_predict_masks.pt2")
assert path.exists(), f"Expected {path} to exist"
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
if task_type == "amg":
assert points_per_batch > 1
if task_type == "sps":
assert points_per_batch == 1
if task_type == "mps":
assert points_per_batch is None
pkg_m = SAM2ImagePredictor_predict_masks(
None,
batch_size=batch_size,
points_per_batch=points_per_batch,
aoti_compiled_model=pkg,
furious=furious,
)
mask_generator.predictor._predict_masks = pkg_m.forward
if task_type in []:
path = Path(model_directory) / Path("sam2_sam_prompt_encoder.pt2")
assert path.exists(), f"Expected {path} to exist"
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
pkg_m = LoadedModel(pkg)
mask_generator.predictor.model.sam_prompt_encoder.forward = pkg_m.forward

if task_type in []:
path = Path(model_directory) / Path("sam2_sam_mask_decoder.pt2")
assert path.exists(), f"Expected {path} to exist"
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
pkg_m = LoadedModel(pkg)
mask_generator.predictor.model.sam_mask_decoder.forward = pkg_m.forward

if task_type in []:
path = Path(model_directory) / Path("sam2_sam_mask_decoder_transformer.pt2")
assert path.exists(), f"Expected {path} to exist"
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
pkg_m = LoadedModel(pkg)
mask_generator.predictor.model.sam_mask_decoder.transformer.forward = (
pkg_m.forward
)

if task_type in ["amg", "sps"]:
path = Path(model_directory) / Path("sam2_image_predict_masks.pt2")
assert path.exists(), f"Expected {path} to exist"
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
if task_type == "amg":
assert points_per_batch > 1
if task_type == "sps":
assert points_per_batch == 1
if task_type == "mps":
assert points_per_batch is None
pkg_m = SAM2ImagePredictor_predict_masks(
None,
batch_size=batch_size,
points_per_batch=points_per_batch,
aoti_compiled_model=pkg,
furious=furious,
)
mask_generator.predictor._predict_masks = pkg_m.forward

print(f"End load image encoder and predict masks. Took {time.time() - t0}s")

@@ -352,12 +498,13 @@ def set_fast(
dynamic=False,
)
elif task_type == "amg":
mask_generator.predictor._predict_masks = torch.compile(
mask_generator.predictor._predict_masks,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
if not loaded_exported_model:
mask_generator.predictor._predict_masks = torch.compile(
mask_generator.predictor._predict_masks,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
else:
# TODO: This might need to be under "allow_recompiles"
# mps encounters rapidly changing points per batch
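The export blocks in this file all follow the same AOTI round trip: a sub-module is compiled ahead of time into a .pt2 package (via the repo's aot_compile helper) and later re-attached with torch._inductor.aoti_load_package. A minimal sketch of that round trip, assuming PyTorch 2.6-era export APIs; TinyDecoder is a hypothetical stand-in, not a SAM2 module:

import torch

class TinyDecoder(torch.nn.Module):
    # Hypothetical stand-in for a SAM2 sub-module such as sam_mask_decoder.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

model = TinyDecoder().eval()
example_args = (torch.randn(1, 256, 64, 64),)

# Export, then compile and package into a single .pt2 artifact.
ep = torch.export.export(model, example_args)
path = torch._inductor.aoti_compile_and_package(ep, package_path="tiny_decoder.pt2")

# Loading returns a plain callable, which is why the diff can assign the
# loaded package directly over attributes like
# mask_generator.predictor.model.sam_mask_decoder.forward.
compiled = torch._inductor.aoti_load_package(path)
out = compiled(*example_args)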
54 changes: 45 additions & 9 deletions examples/sam2_amg_server/generate_data.py
@@ -21,6 +21,38 @@
from tqdm import tqdm


def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result


def memory_runner(path, fn, *args, **kwargs):
print("Start memory recording")
torch.cuda.synchronize()
torch.cuda.memory._record_memory_history(
True, trace_alloc_max_entries=100000, trace_alloc_record_context=True
)
result = fn(*args, **kwargs)
torch.cuda.synchronize()
snapshot = torch.cuda.memory._snapshot()
print("Finish memory recording")
import pickle

with open(path, "wb") as f:
pickle.dump(snapshot, f)
# Use to convert pickle file into html
# python torch/cuda/_memory_viz.py trace_plot <snapshot>.pickle -o <snapshot>.html
return result


def latencies_statistics(data):
# Convert the list to a NumPy array
data_array = np.array(data)
@@ -330,16 +362,17 @@ def decode_img_bytes(img_bytes_tensors, gpu_preproc, baseline):
for img_bytes_tensor in img_bytes_tensors:
with record_function("decode image bytes"):
if gpu_preproc:
# NOTE: We have to use numpy for the baseline
assert not baseline
from torchvision import io as tio

image_tensor = tio.decode_jpeg(
img_bytes_tensor, device="cuda", mode=tio.ImageReadMode.RGB
)
from torchvision.transforms.v2 import functional as F
image_tensor = file_bytes_to_image_tensor(img_bytes_tensor)
from torchvision.transforms import ToTensor, v2

image_tensor = F.to_dtype(image_tensor, torch.float32, scale=True)
if not baseline:
image_tensor = torch.from_numpy(image_tensor)
image_tensor = image_tensor.permute((2, 0, 1))
image_tensor = image_tensor.cuda()
with record_function("v2.ToDtype"):
image_tensor = v2.ToDtype(torch.float32, scale=True)(
image_tensor
)
else:
image_tensor = file_bytes_to_image_tensor(img_bytes_tensor)
from torchvision.transforms import ToTensor
Expand Down Expand Up @@ -431,6 +464,7 @@ def main(
quiet=False,
gpu_preproc=False,
batch_size=1,
seed=42,
):
if batch_size <= 0:
raise ValueError(f"Expected --batch_size to be at least 1 but got {batch_size}")
@@ -502,6 +536,7 @@
from torchao._models.sam2.utils.amg import (
mask_to_rle_pytorch_2 as mask_to_rle_pytorch,
)
torch.manual_seed(seed)
device = "cuda"
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
if verbose:
@@ -628,4 +663,5 @@ def main(
main.__doc__ = main_docstring()
if __name__ == "__main__":
# profiler_runner("asdf.json.gz", fire.Fire, main)
# memory_runner("asdf.pickle", fire.Fire, main)
fire.Fire(main)
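Both helpers added at the top of this file wrap an arbitrary callable, so they can be dropped around any stage of the pipeline, as the commented-out lines above suggest. A usage sketch, assuming a CUDA device and the two helpers in scope; the workload itself is hypothetical:

import torch

def toy_workload(n: int) -> torch.Tensor:
    # Hypothetical stand-in for a generation step.
    x = torch.randn(n, n, device="cuda")
    return x @ x

result = profiler_runner("toy_trace.json.gz", toy_workload, 1024)
result = memory_runner("toy_memory.pickle", toy_workload, 1024)
# Per the comment in memory_runner, render the snapshot to HTML with:
#   python torch/cuda/_memory_viz.py trace_plot toy_memory.pickle -o toy_memory.html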
2 changes: 1 addition & 1 deletion examples/sam2_amg_server/reproduce_experiments.py
@@ -89,7 +89,7 @@ def run(task, output_path: Path, kwargs, baseline_folder=None, environ=None):
stdout, stderr = run_script_with_args(
[
"generate_data.py",
"~/checkpoints/sam2",
f"{str(Path.home())}/checkpoints/sam2",
"large",
task,
image_paths,
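The one-line change above works around the fact that "~" is expanded by a shell, not by subprocess, so the literal path never resolved when handed to the child script. A quick illustration, with illustrative paths:

import os
from pathlib import Path

literal = "~/checkpoints/sam2"                     # the old argument, passed verbatim
expanded = f"{str(Path.home())}/checkpoints/sam2"  # the new, pre-expanded argument

# subprocess does no tilde expansion; the script must do it itself.
assert os.path.expanduser(literal) == expanded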