Mlazos/foreach map rc2.7 #3327

Closed · wants to merge 28 commits
Changes from all commits (28)
af98f76
[DO NOT MERGE] 2.7 RC Test
svekars Mar 18, 2025
bf608af
Update .jenkins/build.sh
svekars Mar 18, 2025
dab6163
Update .jenkins/build.sh
svekars Mar 18, 2025
b183f0a
Update build.sh
svekars Mar 18, 2025
15216e7
Update build.sh
svekars Mar 18, 2025
20365bf
Update build.sh
svekars Mar 18, 2025
d5f8afb
Merge branch 'main' into svekars-patch-36
svekars Mar 19, 2025
e51b6e6
Update onnxscript in requirements (#3300)
justinchuby Mar 19, 2025
bf1f8d1
Merge branch 'main' into 2.7-RC-TEST
svekars Mar 19, 2025
eff088b
Update requirements.txt
svekars Mar 19, 2025
1fca477
Merge branch 'main' into 2.7-RC-TEST
svekars Mar 20, 2025
dc969fd
Update
svekars Mar 21, 2025
a6ff473
Merge branch 'main' into 2.7-RC-TEST
svekars Mar 21, 2025
3cdb01f
Update build.sh
svekars Mar 21, 2025
4b04c9b
Update .jenkins/validate_tutorials_built.py
svekars Mar 21, 2025
4aae3b1
Update build.sh
svekars Mar 21, 2025
327b32b
Update .jenkins/build.sh
svekars Mar 22, 2025
0b69e41
Update build.sh
svekars Mar 24, 2025
4221d54
Apply suggestions from code review
svekars Mar 24, 2025
68e58ae
Update build.sh
svekars Mar 24, 2025
69802e0
Update requirements.txt
svekars Mar 24, 2025
d463fd4
Update .jenkins/build.sh
svekars Mar 24, 2025
81efd5f
Update .jenkins/build.sh
svekars Mar 24, 2025
f7d8e7a
Fix the AOTI example (#3306)
desertfire Mar 25, 2025
f58cf37
Update build.sh
svekars Mar 26, 2025
6e3f90a
Disable rl tutorials again
svekars Mar 26, 2025
ea4e155
Merge branch 'main' into 2.7-RC-TEST
svekars Apr 15, 2025
95d5181
First commit
mlazos Apr 15, 2025
4 changes: 2 additions & 2 deletions .ci/docker/requirements.txt
@@ -36,7 +36,7 @@ datasets
transformers
torchmultimodal-nightly # needs to be updated to stable as soon as it's available
onnx
onnxscript
onnxscript>=0.2.2
onnxruntime
evaluate
accelerate>=0.20.1
@@ -69,5 +69,5 @@ pycocotools
semilearn==0.3.2
torchao==0.5.0
segment_anything==1.0
torchrec==1.0.0; platform_system == "Linux"
torchrec==1.1.0; platform_system == "Linux"
fbgemm-gpu==1.1.0; platform_system == "Linux"
12 changes: 8 additions & 4 deletions .jenkins/build.sh
@@ -22,10 +22,14 @@ sudo apt-get install -y pandoc
# Install PyTorch Nightly for test.
# Nightly - pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
# Install 2.5 to merge all 2.4 PRs - uncomment to install nightly binaries (update the version as needed).
# sudo pip uninstall -y torch torchvision torchaudio torchtext torchdata
# sudo pip3 install torch==2.6.0 torchvision --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu124
# sudo pip uninstall -y fbgemm-gpu torchrec
# sudo pip3 install fbgemm-gpu==1.1.0 torchrec==1.0.0 --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu124
sudo pip uninstall -y torch torchvision torchaudio torchtext torchdata torchrl tensordict
pip3 install torch==2.7.0 torchvision torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu126
#sudo pip uninstall -y fbgemm-gpu
#sudo pip3 install --pre fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu126/
#pip install tensordict-nightly
#pip install torchrl-nightly
#sudo pip3 install fbgemm-gpu==1.1.0 torchrec==1.0.0 --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu126


# Install two language tokenizers for Translation with TorchText tutorial
python -m spacy download en_core_web_sm
9 changes: 8 additions & 1 deletion .jenkins/validate_tutorials_built.py
@@ -51,7 +51,14 @@
"intermediate_source/text_to_speech_with_torchaudio",
"intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
"advanced_source/semi_structured_sparse", # reenable after 3303 is fixed.
"recipes_source/recipes/reasoning_about_shapes"
"intermediate_source/mario_rl_tutorial", # reenable after 3302 is fixed
"intermediate_source/reinforcement_ppo", # reenable after 3302 is fixed
"intermediate_source/pinmem_nonblock", # reenable after 3302 is fixed
"intermediate_source/dqn_with_rnn_tutorial", # reenable after 3302 is fixed
"advanced_source/pendulum", # reenable after 3302 is fixed
"advanced_source/coding_ddpg", # reenable after 3302 is fixed
"intermediate_source/torchrec_intro_tutorial", # reenable after 3302 is fixed
"recipes_source/recipes/reasoning_about_shapes" # reenable after 3326 is fixed
]

def tutorial_source_dirs() -> List[Path]:
2 changes: 1 addition & 1 deletion intermediate_source/torch_export_tutorial.py
@@ -995,7 +995,7 @@ def forward(self, x):
# with torch.no_grad():
# pt2_path = torch._inductor.aoti_compile_and_package(ep)
#
# # Load and run the .so file in Python.
# # Load and run the .pt2 file in Python.
# # To load and run it in a C++ environment, see:
# # https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
# aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
198 changes: 198 additions & 0 deletions recipes_source/foreach_map.py
@@ -0,0 +1,198 @@
"""
(beta) Explicit horizontal fusion with foreach_map and torch.compile
======================================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

#########################################################
# Horizontal fusion is a key optimization in ML compilers. In eager mode,
# it is typically expressed with the ``torch._foreach*`` ops, which parallelize
# an operation across a list of tensors. However, supporting all possible
# permutations of arguments (for example, mixtures of scalars and lists) is
# quite difficult. ``foreach_map`` converts any pointwise op in ``torch`` into a
# horizontally fused foreach variant. In this tutorial, we will demonstrate
# how to implement the Adam optimizer with ``foreach_map`` to generate a
# fully fused kernel.
#
#
# .. note::
#
# This tutorial requires PyTorch 2.7.0 or later.
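#
# As a quick illustration (a minimal sketch of our own, not part of the
# original recipe), compare a Python loop of pointwise adds with the
# horizontally fused eager equivalent:
#
# .. code-block:: python
#
#    import torch
#
#    xs = [torch.ones(3) for _ in range(4)]
#
#    # one kernel launch per tensor in the list
#    for x in xs:
#        x.add_(10)
#
#    # a single horizontally fused op over the whole list
#    torch._foreach_add_(xs, 10)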

#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
# We instantiate an independent copy to compare the two optimizer implementations.
#
import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

# Create a simple model and a deep copy of it, so that both optimizer
# implementations start from identical parameters
import copy

model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
model_copy = copy.deepcopy(model)
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)
output_copy = model_copy(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()
output_copy.sum().backward()

#####################################################################
# Helper functions for foreach_map implementation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll begin our implementation of the Adam optimizer.
#
from torch._higher_order_ops.foreach_map import foreach_map
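
# A small smoke test (our own sketch, not part of the original recipe):
# ``foreach_map`` lifts a pointwise function over lists of tensors, and under
# ``torch.compile`` the mapped ops are horizontally fused into a single kernel.
xs = [torch.randn(8, device="cuda") for _ in range(3)]
ys = [torch.randn(8, device="cuda") for _ in range(3)]
sums = torch.compile(lambda a, b: foreach_map(torch.add, a, b))(xs, ys)
assert all(torch.allclose(s, x + y) for s, x, y in zip(sums, xs, ys))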

# Helper function to extract optimizer states from a ``torch.optim.Adam`` instance
def get_inputs(optim):
    steps = []
    params = []
    exp_avgs = []
    exp_avg_sqs = []
    for group in optim.param_groups:
        for p in group["params"]:
            params.append(p)
            state = optim.state[p]
            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            steps.append(state["step"])

    return steps, params, exp_avgs, exp_avg_sqs


# Functions to update the different optimizer states
def update_exp_avg_sq(exp_avg_sq, grad, beta2):
    # exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad**2
    return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)

def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    # algebraically rearranged Adam step: the negative step size is folded
    # into the denominator so the update is a single add onto param
    bias_correction1 = 1 - torch.pow(beta1, step)
    bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()
    step_size = (lr / bias_correction1).neg()
    denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
    return torch.add(param, torch.div(exp_avg, denom))
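
######################################################################
# For reference, these helpers implement the standard Adam update, restated
# here (our own summary, not part of the original recipe):
#
# .. math::
#
#    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2, \qquad
#    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
#
#    \theta_t = \theta_{t-1} - \frac{\mathrm{lr}}{1 - \beta_1^t} \cdot
#    \frac{m_t}{\sqrt{v_t} / \sqrt{1 - \beta_2^t} + \epsilon}
#
# ``update_param`` folds the negative step size into ``denom`` so that the
# final operation is a single add of ``exp_avg / denom`` onto ``param``.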

# Our full Adam implementation
def foreach_map_adam(
    steps,
    params,
    exp_avgs,
    exp_avg_sqs,
    weight_decay=0,
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    eps=1e-8,
):
    with torch.no_grad():
        grads = [param.grad for param in params]

        # update the step counts
        updated_steps = foreach_map(lambda x: x + 1, steps)
        torch._foreach_copy_(steps, updated_steps)

        if weight_decay != 0:
            # L2 regularization: grad = grad + weight_decay * param
            grads = foreach_map(torch.add, grads, params, alpha=weight_decay)

        # Higher-order operators (HOPs) cannot have multiple outputs at the moment,
        # so we need to call foreach_map once for each output
        exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)
        exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)
        params_updated = foreach_map(
            update_param,
            params,
            steps,
            exp_avgs_updated,
            exp_avgs_sq_updated,
            beta1,
            beta2,
            lr,
            eps,
        )

        # Higher-order operators (HOPs) don't support input mutation today,
        # so manually update the states in-place
        torch._foreach_copy_(exp_avgs, exp_avgs_updated)
        torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)
        torch._foreach_copy_(params, params_updated)

#####################################################################
# Setting up and running the compiled kernel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll run our Adam optimizer
# and compare the results
#
# .. note::
#
# ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
# Create the eager optimizers; we use the same lr as ``foreach_map_adam``'s
# default so the two implementations can be compared directly below
opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(1e-3))
opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(1e-3))

# take one step to initialize the optimizer state (exp_avg, exp_avg_sq, step)
opt_eager.step()
opt_eager_copy.step()

inputs = get_inputs(opt_eager_copy)
compiled_adam = torch.compile(foreach_map_adam)

# optionally view the output code
torch._logging.set_logs(output_code=True)

# Warmup runs to compile the function
for _ in range(5):
    opt_eager.step()
    compiled_adam(*inputs)

# Verify that the compiled foreach_map Adam matches eager Adam
# (a small atol accounts for reordered floating-point arithmetic)
for eager_p, compile_p in zip(
    opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]
):
    assert torch.allclose(eager_p, compile_p, atol=1e-6)

# Benchmark performance

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)
compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")



######################################################################
# Conclusion
# ~~~~~~~~~~
# In this tutorial, we implemented a custom fully-fused Adam optimizer using ``foreach_map``.
# Combining ``foreach_map`` with ``torch.compile`` produced a single horizontally fused kernel
# that matches the eager implementation numerically while eliminating the per-tensor kernel
# overhead of the eager optimizer. The same approach applies to any pointwise computation
# you want to fuse across a list of tensors.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro into the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.
9 changes: 9 additions & 0 deletions recipes_source/recipes_index.rst
@@ -335,6 +335,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
   :link: ../recipes/compiling_optimizer_lr_scheduler.html
   :tags: Model-Optimization

.. (beta) Explicit horizontal fusion with foreach_map and torch.compile

.. customcarditem::
   :header: (beta) Explicit horizontal fusion with foreach_map and torch.compile
   :card_description: Horizontally fuse pointwise ops with torch.compile
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../recipes/foreach_map.html
   :tags: Model-Optimization

.. Using User-Defined Triton Kernels with ``torch.compile``

.. customcarditem::
8 changes: 4 additions & 4 deletions recipes_source/torch_export_aoti_python.py
@@ -176,7 +176,7 @@
model_path = os.path.join(os.getcwd(), "resnet18.pt2")

compiled_model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
example_inputs = torch.randn(2, 3, 224, 224, device=device)

with torch.inference_mode():
    output = compiled_model(example_inputs)
@@ -238,11 +238,11 @@ def timed(fn):

torch._dynamo.reset()

model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)
compiled_model = torch._inductor.aoti_load_package(model_path)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    _, time_taken = timed(lambda: compiled_model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")

