From d95c14878f41a4f20df349de754f10c1fe3bce54 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Mon, 27 Jan 2025 11:13:23 -0800
Subject: [PATCH] add async nan check utils (#965)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/965

Reviewed By: galrotem

Differential Revision: D68530393

fbshipit-source-id: 8afdee2dc74a28b19c0a16eebf0b585f333fdc12
---
 tests/utils/test_nan.py    |  49 +++++++++++++++++
 torchtnt/utils/__init__.py |   3 ++
 torchtnt/utils/nan.py      | 107 +++++++++++++++++++++++++++++++++++++
 3 files changed, 159 insertions(+)
 create mode 100644 tests/utils/test_nan.py
 create mode 100644 torchtnt/utils/nan.py

diff --git a/tests/utils/test_nan.py b/tests/utils/test_nan.py
new file mode 100644
index 0000000000..deff7f1362
--- /dev/null
+++ b/tests/utils/test_nan.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+
+from torchtnt.utils.nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph
+
+
+class NaNFunction(torch.autograd.Function):
+    @staticmethod
+    # pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
+    def forward(ctx, input):
+        return input.clone()
+
+    @staticmethod
+    # pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
+    def backward(ctx, grad_output):
+        return torch.tensor([float("nan")], device="cpu")
+
+
+class NanHookTest(unittest.TestCase):
+    def test_register_nan_hooks_on_whole_graph(self) -> None:
+        x = torch.tensor([1.0], device="cpu", requires_grad=True)
+        out = NaNFunction.apply(x)
+
+        # no error is thrown
+        out.backward()
+
+        _ = register_nan_hooks_on_whole_graph([out])
+        with self.assertRaisesRegex(RuntimeError, "Detected NaN"):
+            out.backward()
+
+    def test_check_for_nan_or_inf(self) -> None:
+        tensor = torch.tensor([float("nan")], device="cpu")
+
+        with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
+            check_for_nan_or_inf(tensor)
+
+        tensor = torch.tensor([float("inf")], device="cpu")
+        with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
+            check_for_nan_or_inf(tensor)
diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index fb8098c360..2e3dc8ae65 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -51,6 +51,7 @@
     ModuleSummary,
     prune_module_summary,
 )
+from .nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph
 from .oom import (
     attach_oom_observer,
     is_out_of_cpu_memory,
@@ -89,6 +90,8 @@
 )
 
 __all__ = [
+    "check_for_nan_or_inf",
+    "register_nan_hooks_on_whole_graph",
     "IsNaNEvaluator",
     "ThresholdEvaluator",
     "CheckpointPath",
diff --git a/torchtnt/utils/nan.py b/torchtnt/utils/nan.py
new file mode 100644
index 0000000000..a13d86e952
--- /dev/null
+++ b/torchtnt/utils/nan.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import deque
+from typing import Callable, Iterator, List, Optional, Sequence, Union
+
+import torch
+from pyre_extensions import none_throws
+from torch.autograd.graph import GradientEdge, Node
+from torch.utils.hooks import RemovableHandle
+
+
+def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, GradientEdge]) -> Node:
+    if isinstance(t, torch.Tensor):
+        return none_throws(t.grad_fn)
+    else:
+        # pyre-ignore Undefined attribute [16]: `GradientEdge` has no attribute `function`.
+        return t.function if t is not None else None
+
+
+def register_nan_hooks_on_whole_graph(  # noqa: C901
+    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]]
+) -> Callable[[], None]:
+    """
+    Registers a NaN hook on the whole autograd graph of the given tensors. The hook raises an error if a NaN is detected.
+
+    This is useful if you want training to halt when a NaN is detected during the autograd process (i.e. the loss is inf or NaN).
+
+    Usage:
+
+    >>> class NaNFunction(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input.clone()
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            return torch.tensor([float("nan")], device="cpu")
+    >>> x = torch.tensor([1.0], device="cpu", requires_grad=True)
+    >>> out = NaNFunction.apply(x)
+    >>> _ = register_nan_hooks_on_whole_graph([out])
+    >>> out.backward()
+    RuntimeError: Detected NaN in 'grad_inputs[0]' after executing Node
+
+    """
+
+    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
+
+    def iter_graph(roots: List[torch.autograd.graph.Node]) -> Iterator[Node]:
+        if not roots:
+            return
+        seen = set()
+        q = deque()
+        for node in roots:
+            if node is not None and node not in seen:
+                seen.add(node)
+                q.append(node)
+        while q:
+            node = q.popleft()
+            for fn, _ in node.next_functions:
+                if fn is None or fn in seen:
+                    continue
+                seen.add(fn)
+                q.append(fn)
+            yield node
+
+    def _assert_no_nan_tensor(t: Optional[torch.Tensor], msg: str) -> None:
+        if t is not None:
+            torch._assert_async(torch.logical_not(torch.any(torch.isnan(t))), msg)
+
+    def posthook(
+        grad_inputs: Sequence[Optional[torch.Tensor]],
+        grad_outputs: Sequence[Optional[torch.Tensor]],
+    ) -> None:
+        node = torch._C._current_autograd_node()
+        for i, g_in in enumerate(grad_inputs):
+            _assert_no_nan_tensor(
+                g_in, f"Detected NaN in 'grad_inputs[{i}]' after executing Node: {node}"
+            )
+
+    handles: List[RemovableHandle] = []
+    for node in iter_graph(grad_fns):
+        posthandle = node.register_hook(posthook)
+        handles.append(posthandle)
+
+    def unregister_hooks() -> None:
+        for handle in handles:
+            handle.remove()
+
+    return unregister_hooks
+
+
+def check_for_nan_or_inf(
+    tensor: torch.Tensor, msg: str = "Detected NaN or Inf in tensor"
+) -> None:
+    """
+    Asynchronously assert that the tensor contains neither NaN nor Inf values. This
+    will produce a CUDA device-side assert error if the tensor is on a GPU.
+    """
+
+    torch._assert_async(
+        torch.logical_not(torch.any(torch.isnan(tensor) | torch.isinf(tensor))),
+        msg,
+    )
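
For reference, a minimal sketch of how the two new utilities might be combined in a training step. The model, optimizer, and data below are illustrative placeholders and are not part of this patch.

    # Illustrative only: a toy model and one training step, not part of the patch.
    import torch
    from torchtnt.utils.nan import (
        check_for_nan_or_inf,
        register_nan_hooks_on_whole_graph,
    )

    model = torch.nn.Linear(8, 1)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs, targets = torch.randn(4, 8), torch.randn(4, 1)  # placeholder batch

    loss = torch.nn.functional.mse_loss(model(inputs), targets)

    # Asynchronously assert that the loss itself is finite
    # (becomes a device-side assert when the tensor lives on a GPU).
    check_for_nan_or_inf(loss)

    # Register post-hooks on every autograd node reachable from the loss so that
    # backward() raises if any grad_input contains a NaN.
    unregister = register_nan_hooks_on_whole_graph([loss])
    loss.backward()
    optimizer.step()

    # Remove the hooks once they are no longer needed (e.g. after debugging).
    unregister()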