add async nan check utils (#965)
Summary:
Pull Request resolved: #965

Reviewed By: galrotem

Differential Revision: D68530393

fbshipit-source-id: 8afdee2dc74a28b19c0a16eebf0b585f333fdc12
JKSenthil authored and facebook-github-bot committed Jan 27, 2025
1 parent 06e6207 commit d95c148
Showing 3 changed files with 159 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/utils/test_nan.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch

from torchtnt.utils.nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph


class NaNFunction(torch.autograd.Function):
    @staticmethod
    # pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
    def forward(ctx, input):
        return input.clone()

    @staticmethod
    # pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
    def backward(ctx, grad_output):
        return torch.tensor([float("nan")], device="cpu")


class NanHookTest(unittest.TestCase):
    def test_register_nan_hooks_on_whole_graph(self) -> None:
        x = torch.tensor([1.0], device="cpu", requires_grad=True)
        out = NaNFunction.apply(x)

        # no error is thrown
        out.backward()

        _ = register_nan_hooks_on_whole_graph([out])
        with self.assertRaisesRegex(RuntimeError, "Detected NaN"):
            out.backward()

    def test_check_for_nan_or_inf(self) -> None:
        tensor = torch.tensor([float("nan")], device="cpu")

        with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
            check_for_nan_or_inf(tensor)

        tensor = torch.tensor([float("inf")], device="cpu")
        with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
            check_for_nan_or_inf(tensor)
3 changes: 3 additions & 0 deletions torchtnt/utils/__init__.py
@@ -51,6 +51,7 @@
    ModuleSummary,
    prune_module_summary,
)
from .nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph
from .oom import (
    attach_oom_observer,
    is_out_of_cpu_memory,
@@ -89,6 +90,8 @@
)

__all__ = [
"check_for_nan_or_inf",
"register_nan_hooks_on_whole_graph",
"IsNaNEvaluator",
"ThresholdEvaluator",
"CheckpointPath",
107 changes: 107 additions & 0 deletions torchtnt/utils/nan.py
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import deque
from typing import Callable, Iterator, List, Optional, Sequence, Union

import torch
from pyre_extensions import none_throws
from torch.autograd.graph import GradientEdge, Node
from torch.utils.hooks import RemovableHandle


def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, GradientEdge]) -> Node:
    if isinstance(t, torch.Tensor):
        return none_throws(t.grad_fn)
    else:
        # pyre-ignore Undefined attribute [16]: `GradientEdge` has no attribute `function`.
        return t.function if t is not None else None


def register_nan_hooks_on_whole_graph(  # noqa: C901
    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]]
) -> Callable[[], None]:
    """
    Registers a NaN hook on the whole autograd graph of the given tensors. The hook will throw an error if a NaN is detected.
    This is useful if you want training to halt when a NaN is detected during the autograd process (i.e. the loss is inf or NaN).

    Usage:

    >>> class NaNFunction(torch.autograd.Function):
    ...     @staticmethod
    ...     def forward(ctx, input):
    ...         return input.clone()
    ...
    ...     @staticmethod
    ...     def backward(ctx, grad_output):
    ...         return torch.tensor([float("nan")], device="cpu")
    >>> x = torch.tensor([1.0], device="cpu", requires_grad=True)
    >>> out = NaNFunction.apply(x)
    >>> _ = register_nan_hooks_on_whole_graph([out])
    >>> out.backward()
    RuntimeError: Detected NaN in 'grad_inputs[0]' after executing Node
    """

    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots: List[torch.autograd.graph.Node]) -> Iterator[Node]:
        if not roots:
            return
        seen = set()
        q = deque()
        for node in roots:
            if node is not None and node not in seen:
                seen.add(node)
                q.append(node)
        while q:
            node = q.popleft()
            for fn, _ in node.next_functions:
                if fn is None or fn in seen:
                    continue
                seen.add(fn)
                q.append(fn)
            yield node

    def _assert_no_nan_tensor(t: Optional[torch.Tensor], msg: str) -> None:
        if t is not None:
            torch._assert_async(torch.logical_not(torch.any(torch.isnan(t))), msg)

    def posthook(
        grad_inputs: Sequence[Optional[torch.Tensor]],
        grad_outputs: Sequence[Optional[torch.Tensor]],
    ) -> None:
        node = torch._C._current_autograd_node()
        for i, g_in in enumerate(grad_inputs):
            _assert_no_nan_tensor(
                g_in, f"Detected NaN in 'grad_inputs[{i}]' after executing Node: {node}"
            )

    handles: List[RemovableHandle] = []
    for node in iter_graph(grad_fns):
        posthandle = node.register_hook(posthook)
        handles.append(posthandle)

    def unregister_hooks() -> None:
        for handle in handles:
            handle.remove()

    return unregister_hooks


def check_for_nan_or_inf(
    tensor: torch.Tensor, msg: str = "Detected NaN or Inf in tensor"
) -> None:
    """
    Asynchronously assert that the tensor contains neither NaN nor infinity. This will
    produce a CUDA device-side assert error if the tensor is on GPU.
    """

    torch._assert_async(
        torch.logical_not(torch.any(torch.isnan(tensor) | torch.isinf(tensor))),
        msg,
    )
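
Below is a minimal usage sketch, not part of this commit, of how the two new utilities from torchtnt.utils.nan might be combined in a training step; the model, optimizer, data, and loss shown are hypothetical placeholders.

# Illustrative sketch only; model/optimizer/loss names are placeholders.
import torch

from torchtnt.utils.nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


def train_step(inputs: torch.Tensor, targets: torch.Tensor) -> None:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

    # Raises (or triggers a device-side assert on GPU) if the loss is NaN or Inf.
    check_for_nan_or_inf(loss)

    # Register post-hooks on every node reachable from `loss`; backward() then
    # raises "Detected NaN in 'grad_inputs[...]'" if any gradient contains a NaN.
    unregister = register_nan_hooks_on_whole_graph([loss])
    try:
        loss.backward()
    finally:
        unregister()

    optimizer.step()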
