Skip to content

Commit 8e05749

Browse files
PenXLa and pytorchmergebot
authored and committed
Fix integer overflow bug in triu/tril for large diagonal values (pytorch#153240)
This PR fixes a bug in the implementation of `apply_triu_tril_single` where using extremely large values for the diagonal argument (e.g. `diagonal=9223372036854775807`) could result in integer overflow and incorrect results. The masking logic is rewritten to avoid this issue by always iterating over all columns, ensuring correctness even for large or extreme diagonal values.

Example of the original incorrect behavior:

```python
a = torch.ones(5,5)
torch.triu(a, 9223372036854775807)
# Before:
# tensor([[0., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1.]])
```

The new implementation guards against overflow and produces correct results for all valid input values.

Pull Request resolved: pytorch#153240
Approved by: https://github.com/albanD
1 parent b334a5a commit 8e05749

2 files changed

Lines changed: 34 additions & 0 deletions

File tree

aten/src/ATen/native/TriangularOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ void apply_triu_tril_single(
5252
int64_t self_col_stride,
5353
bool upper) {
5454
constexpr int64_t zero = 0;
55+
k = std::clamp(k, -n, m); // Clamp k to [-n, m] to prevent i + k arithmetic overflow, especially if k approaches INT64_MAX/INT64_MIN.
5556

5657
if (upper) {
5758
parallel_for(0, n, 0, [&](int64_t start, int64_t end) {

test/test_linalg.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9840,6 +9840,39 @@ def test_matmul_mv(self, device, dtype):
98409840
C = torch.matmul(A, B)
98419841
self.assertEqual(C, B.sum().expand(B.shape))
98429842

9843+
@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
9844+
def test_triu_tril_extreme_k_values(self, device, dtype):
9845+
"""
9846+
Test triu/tril with extreme k values to verify overflow fix.
9847+
Regression test for https://github.com/pytorch/pytorch/pull/153240
9848+
"""
9849+
# Create test matrices
9850+
a = make_tensor((5, 5), dtype=dtype, device=device)
9851+
9852+
# Test extreme positive k value
9853+
k_max = 9223372036854775807
9854+
result_triu_max = torch.triu(a, k_max)
9855+
result_tril_max = torch.tril(a, k_max)
9856+
9857+
# With k = INT64_MAX, triu should return all zeros (since i + k will exceed matrix bounds for all i,j)
9858+
# and tril should return the full matrix (since i + k + 1 will exceed matrix bounds for all i,j)
9859+
expected_triu_max = torch.zeros_like(a)
9860+
expected_tril_max = a.clone()
9861+
self.assertEqual(result_triu_max, expected_triu_max)
9862+
self.assertEqual(result_tril_max, expected_tril_max)
9863+
9864+
# Test extreme negative k value
9865+
k_min = -9223372036854775808
9866+
result_triu_min = torch.triu(a, k_min)
9867+
result_tril_min = torch.tril(a, k_min)
9868+
9869+
# With k = INT64_MIN, triu should return the full matrix (since i + k will be negative for all i,j)
9870+
# and tril should return all zeros (since i + k + 1 will be negative for all i,j)
9871+
expected_triu_min = a.clone()
9872+
expected_tril_min = torch.zeros_like(a)
9873+
self.assertEqual(result_triu_min, expected_triu_min)
9874+
self.assertEqual(result_tril_min, expected_tril_min)
9875+
98439876
@dtypes(torch.float, torch.double)
98449877
@precisionOverride({torch.float32: 1e-4})
98459878
def test_1_sized_with_0_strided(self, device, dtype):

0 commit comments

Comments
 (0)