This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 701647b

Andrew Gu authored and facebook-github-bot committed
Reduced CPU overhead in precompute_float8_dynamic_scale_for_fsdp (#331)
Summary:
Pull Request resolved: #331

**Description**

For Llama3-8B on 8xH100, profiling with `with_stack=True` (which does add overhead), the `precompute_float8_dynamic_scale_for_fsdp` CPU time decreases from 24 ms to 15 ms.

Before:
<img width="600" alt="Screenshot 2024-07-25 at 10 16 38 AM" src="https://github.com/user-attachments/assets/5d2384a0-6864-4bdc-91db-90cae809c702">

After:
<img width="638" alt="Screenshot 2024-07-25 at 10 17 00 AM" src="https://github.com/user-attachments/assets/1dbf3b2e-a576-4cdf-ac4f-06ae96020c38">

**Test Plan**

```
(pytorch-3.10) [[email protected] /data/users/andgu/float8_experimental (precompute_float8)]$ pytest test/test_fsdp2/test_fsdp2.py
========================================================= test session starts =========================================================
platform linux -- Python 3.10.13, pytest-7.3.2, pluggy-1.3.0
rootdir: /data/users/andgu/float8_experimental
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, shard-0.1.2, rerunfailures-13.0, flakefinder-1.1.0, cpp-2.3.0
collected 8 items
Running 8 items in this shard

test/test_fsdp2/test_fsdp2.py ........                                                                                          [100%]

========================================================== warnings summary ===========================================================
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_multi_module_parity
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_single_module_parity
  /data/users/andgu/float8_experimental/float8_experimental/float8_linear_utils.py:272: FutureWarning: The combination of ranks + tag as process group identifier has been deprecated. Please switch to using ProcessGroup, DeviceMesh, or group name instead.
    all_reduced_amax_tensor = all_reduce(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================== 8 passed, 2 warnings in 121.90s (0:02:01) ==============================================
```

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: weifengpy

Differential Revision: D60236258

Pulled By: awgu

fbshipit-source-id: 7b1e48d431dac25d534a77d64d1e5571ad3ad807
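For illustration, here is a minimal sketch of why the new pattern issues fewer CPU ops, using plain CPU tensors rather than the DTensors the real code operates on (variable names mirror the patch but the rest is illustrative): `torch.vstack` promotes each 0-dim amax to shape `(N, 1)`, so the old path needed a `split` plus a `squeeze` per weight, while `torch.stack` keeps a flat `(N,)` tensor that can be indexed directly.

```python
# Sketch contrasting the old and new patterns on plain CPU tensors.
# Illustrative only; the real code operates on DTensors under FSDP2.
import math

import torch

weights = [torch.randn(16, 16) for _ in range(4)]
# inf-norm per weight is equivalent to max(abs(w)); returns 0-dim tensors
max_weights = torch._foreach_norm(weights, ord=math.inf)

# Old: vstack promotes each 0-dim amax to (1, 1), giving an (N, 1) result;
# split() then materializes N views that each need a squeeze().
amax_2d = torch.vstack(max_weights)                           # (N, 1)
scale_2d = torch.finfo(torch.float8_e4m3fn).max / amax_2d
old_scales = [s.squeeze() for s in torch.split(scale_2d, 1)]  # N splits + N squeezes

# New: stack keeps a flat (N,) tensor; plain indexing yields each 0-dim
# scale in a single op, so far fewer operator dispatches on the CPU.
amax_1d = torch.stack(max_weights)                            # (N,)
scale_1d = torch.finfo(torch.float8_e4m3fn).max / amax_1d
new_scales = [scale_1d[i] for i in range(len(weights))]

for old, new in zip(old_scales, new_scales):
    torch.testing.assert_close(old, new)
```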
1 parent a6cef5a · commit 701647b

File tree

1 file changed (+4, -6)

float8_experimental/fsdp_utils.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-    amax_tensor = torch.vstack(max_weights)  # Partial
+    amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    scales = torch.split(scale_tensor, 1)  # Replicate
-    for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = (
-            scale._local_tensor.squeeze()
-        )
+    local_scale_tensor = scale_tensor.to_local()
+    for i, float8_linear in enumerate(float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
```
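The second half of the change swaps N per-slice `._local_tensor` accesses for a single `to_local()` call followed by plain indexing. A rough single-rank illustration follows; the setup code here is hypothetical (gloo backend, 1-device CPU mesh), and DTensor import paths may vary across PyTorch versions:

```python
# Single-rank DTensor sketch of the unwrapping change (hypothetical setup;
# the real code runs under FSDP2 on a multi-GPU mesh).
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

scale_tensor = DTensor.from_local(torch.rand(4), mesh, [Replicate()])

# Old: split the DTensor into N single-element DTensors, then unwrap and
# squeeze each one -- every split slice dispatches through the subclass.
old = [s._local_tensor.squeeze() for s in torch.split(scale_tensor, 1)]

# New: unwrap once, then index the plain local tensor.
local_scale_tensor = scale_tensor.to_local()
new = [local_scale_tensor[i] for i in range(4)]

for a, b in zip(old, new):
    torch.testing.assert_close(a, b)
dist.destroy_process_group()
```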
