Skip to content

Commit 25db86c

Browse files
Ailing Zhangfacebook-github-bot
Ailing Zhang
authored andcommitted
Fix isfinite for int input (pytorch#12750)
Summary: `torch.isfinite()` used to crash on int inputs. ``` >>> import torch >>> a = torch.tensor([1, 2]) >>> torch.isfinite(a) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/scratch/pytorch/torch/functional.py", line 262, in isfinite return (tensor == tensor) & (tensor.abs() != inf) RuntimeError: value cannot be converted to type int64_t without overflow: inf ``` But this is a easy special case and numpy also supports it. ``` >>> import numpy as np >>> a = np.array([1, 2]) >>> a.dtype dtype('int64') >>> np.isfinite(a) array([ True, True], dtype=bool) ``` So added a hacky line to handle non-floating-point input. Since pytorch raises exception when overflow, we can safely assume all valid int tensors are infinite numbers. Pull Request resolved: pytorch#12750 Differential Revision: D10428204 Pulled By: ailzhang fbshipit-source-id: f39b2d0975762c91cdea23c766ff1e21d85d57a5
1 parent 9a76e84 commit 25db86c

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

test/test_torch.py

+4
Original file line numberDiff line numberDiff line change
@@ -5201,6 +5201,10 @@ def test_isfinite(self):
52015201
x = torch.Tensor([1, inf, 2, -inf, nan, -10])
52025202
self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 0, 1, 0, 0, 1]))
52035203

5204+
def test_isfinite_int(self):
5205+
x = torch.tensor([1, 2, 3])
5206+
self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 1, 1]))
5207+
52045208
def test_isinf(self):
52055209
x = torch.Tensor([1, inf, 2, -inf, nan])
52065210
self.assertEqual(torch.isinf(x), torch.ByteTensor([0, 1, 0, 1, 0]))

torch/functional.py

+7
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ def isfinite(tensor):
259259
"""
260260
if not isinstance(tensor, torch.Tensor):
261261
raise ValueError("The argument is not a tensor", str(tensor))
262+
263+
# Support int input, nan and inf are concepts in floating point numbers.
264+
# Numpy uses type 'Object' when the int overflows long, but we don't
265+
# have a similar concept. It's safe to assume any created LongTensor doesn't
266+
# overflow and it's finite.
267+
if not tensor.is_floating_point():
268+
return torch.ones_like(tensor, dtype=torch.uint8)
262269
return (tensor == tensor) & (tensor.abs() != inf)
263270

264271

0 commit comments

Comments
 (0)