Skip to content

Commit 1d7397a

Browse files
DDEle authored and pytorchmergebot committed
[Inductor] Avoid tensor slice overflow for large step (pytorch#147433)
Fixes pytorch#147071. Currently, if step is a value very close to INT64_MAX, the calculation of the slice output length will overflow. This PR fixes this problem and thus fixes pytorch#147071. Pull Request resolved: pytorch#147433. Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
1 parent 9c506aa commit 1d7397a

File tree

5 files changed

+32
-4
lines changed

5 files changed

+32
-4
lines changed

aten/src/ATen/native/TensorShape.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -3059,7 +3059,8 @@ Tensor slice(
30593059
}
30603060
auto storage_offset = self.storage_offset() + start_val * strides[dim];
30613061
auto len = end_val - start_val;
3062-
sizes[dim] = (len + step - 1) / step; // round-up
3062+
sizes[dim] =
3063+
(len == 0) ? 0 : (1 + (len - 1) / step); // round-up, avoiding overflow
30633064
strides[dim] *= step;
30643065

30653066
Tensor result;

test/expect/HasDecompTest.test_has_decomposition.expect

-2
Original file line numberDiff line numberDiff line change
@@ -1167,8 +1167,6 @@ aten::set_
11671167
aten::set_.source_Storage
11681168
aten::set_.source_Storage_storage_offset
11691169
aten::set_.source_Tensor
1170-
aten::slice_copy.Tensor
1171-
aten::slice_copy.Tensor_out
11721170
aten::slice_inverse
11731171
aten::slow_conv3d_forward
11741172
aten::slow_conv3d_forward.output

test/inductor/test_torchinductor.py

+11
Original file line numberDiff line numberDiff line change
@@ -12527,6 +12527,17 @@ def f(x):
1252712527
ms = do_bench(lambda: opt_f(x))
1252812528
print(f"{ms=:.3f}")
1252912529

12530+
def test_slice_overflow(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/147071:
    # a step close to INT64_MAX used to overflow the computation of the
    # slice output length. Compiled output must match eager.
    def fn(x):
        sliced = torch.slice_copy(
            x, dim=0, start=449, end=None, step=9223372036854775807
        )
        return torch.reciprocal(sliced)

    x = torch.randn((875,))
    self.assertEqual(torch.compile(fn)(x), fn(x))
1253012541
@torch._inductor.config.patch("graph_partition", True)
1253112542
def test_graph_partition_no_inputs(self):
1253212543
def foo():

torch/_decomp/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
477477
aten.sinc,
478478
aten.sinc_,
479479
aten.slice_backward,
480+
aten.slice_copy,
480481
aten.smooth_l1_loss,
481482
aten.smooth_l1_loss_backward,
482483
aten.soft_margin_loss,

torch/_decomp/decompositions.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def slice_forward(
748748

749749
storage_offset = self.storage_offset() + start_val * strides[dim]
750750
len = end_val - start_val
751-
sizes[dim] = (len + step - 1) // step
751+
sizes[dim] = -(len // -step) # round-up, avoiding overflow
752752
strides[dim] *= step
753753

754754
if self.is_quantized:
@@ -759,6 +759,23 @@ def slice_forward(
759759
return self.as_strided(sizes, strides, storage_offset)
760760

761761

762+
@register_decomposition([aten.slice_copy.Tensor, aten.slice_copy.Tensor_out])
def slice_copy(
    self: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
    out: Optional[Tensor] = None,
):
    """Decomposition of aten.slice_copy: a contiguous copy of a slice.

    Reuses slice_forward to compute the strided view, then clones it so
    the result owns its own (contiguous) storage. When ``out`` is given,
    the copy is written into it instead of being returned directly.
    """
    view = slice_forward(self, dim, start, end, step)
    copied = view.clone(memory_format=torch.contiguous_format)
    if out is None:
        return copied
    return _safe_copy_out(copy_from=copied, copy_to=out, exact_dtype=True)
777+
778+
762779
def _normalize_start_end(
763780
x: Tensor, dim: int, start: Optional[int], end: Optional[int]
764781
) -> tuple[int, int]:

0 commit comments

Comments
 (0)