
Commit dab48e4

yunjiangster authored and facebook-github-bot committed
Avoid nan by using zeros instead of empty dummy tensor (#2648)
Summary:
Pull Request resolved: #2648

This appears to solve the nan issue that shows up when we enable `torch.autograd.set_detect_anomaly(True)`.

The error below appears non-deterministically after a few training steps, presumably because `self._dummy_tensor` can start out very large but not nan, and reaches `nan` after a few iterations.

```
RuntimeError: Function 'All2All_Seq_Req_WaitBackward' returned nan values in its 0th output; num_outputs = 1; num_inputs = 0; outputs[0].shape = [1, ]; outputs[i] = nan [ torch.cuda.FloatTensor{1} ]
```

Reviewed By: iamzainhuda

Differential Revision: D67535635

fbshipit-source-id: 50a70163afa8d17b3ed6f6c59c118315193c9839
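For context, a minimal sketch (not part of the commit message) of the anomaly-detection setup the summary refers to: with `torch.autograd.set_detect_anomaly(True)`, autograd raises as soon as a backward function produces nan, which is how the `All2All_Seq_Req_WaitBackward` failure above surfaces. The tensor value here is a hypothetical stand-in for a dummy tensor whose uninitialized memory has drifted to nan.

```python
import torch

# Enable anomaly detection so nan outputs of backward functions raise immediately.
torch.autograd.set_detect_anomaly(True)

# Stand-in for a dummy tensor whose uninitialized memory has become nan.
x = torch.tensor([float("nan")], requires_grad=True)
y = (x * x).sum()

# Raises: RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.
y.backward()
```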
1 parent 1f0681e commit dab48e4

File tree

1 file changed: +7 -4 lines changed


torchrec/distributed/comm_ops.py

+7 -4
```diff
@@ -107,10 +107,13 @@ def __init__(self, pg: dist.ProcessGroup, device: torch.device) -> None:
         # This dummy tensor is used to build the autograd graph between
         # CommOp-Req and CommOp-Await. The actual forward tensors, and backwards gradient tensors
         # are stored in self.tensor
-        self.dummy_tensor: torch.Tensor = torch.empty(
-            1,
-            requires_grad=True,
-            device=device,
+        # torch.zeros is a call_function, not placeholder, hence fx.trace incompatible.
+        self.dummy_tensor: torch.Tensor = torch.zeros_like(
+            torch.empty(
+                1,
+                requires_grad=True,
+                device=device,
+            )
         )
 
     def _wait_impl(self) -> W:
```
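As a rough illustration of why the change helps (my own sketch, not code from the patch): `torch.empty` leaves the tensor's memory uninitialized, so the dummy value can be arbitrarily large, inf, or already nan on any given run, whereas zeroing it guarantees a deterministic, finite starting value.

```python
import torch

# torch.empty does not initialize memory: the value is whatever bytes happened
# to be there, and may be huge, inf, or nan on any given run.
uninitialized = torch.empty(1)

# Zeroing the same kind of allocation gives a deterministic starting value.
zeroed = torch.zeros_like(torch.empty(1))

print(uninitialized)  # run-dependent, e.g. tensor([4.2039e-45]) or tensor([nan])
print(zeroed)         # always tensor([0.])
```

The patch routes through `torch.zeros_like(torch.empty(...))` rather than a plain `torch.zeros(...)` for the fx-tracing reason noted in the in-diff comment.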
