Reduced CPU overhead in precompute_float8_dynamic_scale_for_fsdp (#331)
Summary:
Pull Request resolved: #331
**Description**
When profiling Llama3-8B on 8xH100 with `with_stack=True` (which itself adds overhead), the CPU time of `precompute_float8_dynamic_scale_for_fsdp` decreases from 24 ms to 15 ms.
Before:
<img width="600" alt="Profiler trace before: precompute_float8_dynamic_scale_for_fsdp at 24 ms CPU time" src="https://github.com/user-attachments/assets/5d2384a0-6864-4bdc-91db-90cae809c702">
After:
<img width="638" alt="Profiler trace after: precompute_float8_dynamic_scale_for_fsdp at 15 ms CPU time" src="https://github.com/user-attachments/assets/1dbf3b2e-a576-4cdf-ac4f-06ae96020c38">
**Test Plan**
```
(pytorch-3.10) [[email protected] /data/users/andgu/float8_experimental (precompute_float8)]$ pytest test/test_fsdp2/test_fsdp2.py
========================================================= test session starts =========================================================
platform linux -- Python 3.10.13, pytest-7.3.2, pluggy-1.3.0
rootdir: /data/users/andgu/float8_experimental
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, shard-0.1.2, rerunfailures-13.0, flakefinder-1.1.0, cpp-2.3.0
collected 8 items
Running 8 items in this shard
test/test_fsdp2/test_fsdp2.py ........ [100%]
========================================================== warnings summary ===========================================================
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_multi_module_parity
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_single_module_parity
/data/users/andgu/float8_experimental/float8_experimental/float8_linear_utils.py:272: FutureWarning: The combination of ranks + tag as process group identifier has been deprecated. Please switch to using ProcessGroup, DeviceMesh, or group name instead.
all_reduced_amax_tensor = all_reduce(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================== 8 passed, 2 warnings in 121.90s (0:02:01) ==============================================
```
imported-using-ghimport
Test Plan: Imported from OSS
Reviewed By: weifengpy
Differential Revision: D60236258
Pulled By: awgu
fbshipit-source-id: 7b1e48d431dac25d534a77d64d1e5571ad3ad807