This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 31fba04

y-sq authored and facebook-github-bot committed
Fix an issue in sync_amax (#169)
Summary: To fix this error:

```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor []] is at version 1; expected version 0 instead.
```

Also tried

```
torch.no_grad()
def sync_float8_amax_and_scale_history(
```

which didn't work.

We can look into whether there are any better ways to fix this.

Pull Request resolved: #169

Test Plan: ./test/test_fsdp.sh

Reviewed By: vkuzo

Differential Revision: D52373985

Pulled By: y-sq

fbshipit-source-id: a25f4b0fee21dd5801c444b28f8a2f878bbafa35
1 parent f4812ee commit 31fba04
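
The error quoted in the summary comes from autograd's saved-tensor version check: PyTorch records a tensor's version counter when it saves the tensor for backward, and raises if an in-place write bumps that counter before backward runs. A plausible reading of the fix is that slices like `fp8_amax_x_tensor[idx]` are views sharing one version counter with the stacked tensor, so an in-place update through any of them can invalidate a tensor another layer saved for backward; `.clone()` breaks that aliasing. Below is a minimal, self-contained sketch of that failure mode and of the clone workaround. It is not the float8_experimental code; the names `demo`, `stacked`, `slice_a`, `slice_b`, and `w` are made up for illustration.

```python
import torch

def demo(use_clone: bool) -> None:
    # A stacked buffer whose slices are handed out to two consumers,
    # loosely mirroring handing slices of a stacked amax tensor back
    # to individual layers.
    stacked = torch.ones(2)

    # Views share storage and a version counter with `stacked`; clones do not.
    slice_a = stacked[0].clone() if use_clone else stacked[0]
    slice_b = stacked[1].clone() if use_clone else stacked[1]

    w = torch.randn((), requires_grad=True)

    # Autograd saves `slice_a` here, since it is needed to compute d(out)/d(w).
    out = w * slice_a

    # A later in-place update through the *other* slice. Without clone, this
    # bumps the shared version counter and invalidates the saved tensor.
    slice_b.fill_(5.0)

    out.backward()

for use_clone in (False, True):
    try:
        demo(use_clone)
        print(f"use_clone={use_clone}: backward succeeded")
    except RuntimeError as err:
        print(f"use_clone={use_clone}: {err}")
```

Running the sketch prints the same version-counter RuntimeError for the view case and a successful backward for the cloned case.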

File tree: 1 file changed, +3 / -3 lines


float8_experimental/float8_linear_utils.py

Lines changed: 3 additions & 3 deletions
@@ -163,9 +163,9 @@ def sync_float8_amax_and_scale_history(
         # 1. in distributed contexts, syncs amax values across workers
         #
         if dist.is_initialized():
-            child.fp8_amax_x = fp8_amax_x_tensor[idx]
-            child.fp8_amax_w = fp8_amax_w_tensor[idx]
-            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx]
+            child.fp8_amax_x = fp8_amax_x_tensor[idx].clone()
+            child.fp8_amax_w = fp8_amax_w_tensor[idx].clone()
+            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone()

         #
         # 2. adds the `amax` values to history
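
For context on where the changed lines sit, here is a rough sketch of the sync pattern they belong to, under stated assumptions: a list of layers, each exposing an `fp8_amax_x` buffer, stacked into one tensor, reduced across ranks, and written back per layer. The function name `sync_amax_sketch`, the `layers` argument, and the MAX reduction are illustrative guesses, not the library's actual implementation; only the `fp8_amax_x` attribute, the per-index slicing, and the `.clone()` are taken from the diff.

```python
import torch
import torch.distributed as dist

def sync_amax_sketch(layers) -> None:
    # Stack each layer's scalar amax buffer so one collective can
    # reduce all of them at once.
    amax_tensor = torch.stack([layer.fp8_amax_x for layer in layers])

    if dist.is_initialized():
        # Keep the largest amax observed on any rank.
        dist.all_reduce(amax_tensor, op=dist.ReduceOp.MAX)

    for idx, layer in enumerate(layers):
        # Without .clone(), every layer's buffer would be a view into
        # amax_tensor, sharing its storage and version counter; the
        # clone gives each layer an independent tensor again.
        layer.fp8_amax_x = amax_tensor[idx].clone()
```

The same reasoning would apply to the `fp8_amax_w` and `fp8_amax_dL_dY` buffers changed in the same hunk.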
