Skip to content

Commit e4bb565

Browse files
nairbv authored and facebook-github-bot committed
Preemptively test for out-of-order length. (pytorch#13933)
Summary: torch.nn.utils.rnn.pack_padded_sequence segment fault if not in decreasing order pytorch#13324 We were seeing this segfault on throw, pre-emptively checking avoids this: *** Error in `/home/bvaughan/anaconda3/bin/python': double free or corruption (!prev): 0x00005555566e7510 *** Pull Request resolved: pytorch#13933 Differential Revision: D13090389 Pulled By: nairbv fbshipit-source-id: 6f6b319e74cb55830be799e9c46bc33aa59256d8
1 parent c7a247f commit e4bb565

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

aten/src/ATen/native/PackedSequence.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten
2222
AT_CHECK(lengths[batch_size - 1] > 0,
2323
"Length of all samples has to be greater than 0, but found an element "
2424
"in 'lengths' that is <= 0");
25+
for(auto i = 0; i < batch_size - 1; i++) {
26+
if (lengths[batch_size - 1 - i] > lengths[batch_size - 2 - i]) {
27+
AT_ERROR("'lengths' array has to be sorted in decreasing order");
28+
}
29+
}
2530

2631
std::vector<at::Tensor> steps;
2732
steps.reserve(batch_size);
@@ -67,9 +72,8 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten
6772
(*batch_sizes++) = current_batch_size;
6873
}
6974
prev_l = l;
70-
} else if (prev_l > l) {
71-
AT_ERROR("'lengths' array has to be sorted in decreasing order");
7275
}
76+
AT_CHECK(l >= prev_l);
7377
}
7478

7579
return std::make_tuple(at::cat(steps), batch_sizes_t);

test/test_nn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,13 @@ def test_cuda_mask(self):
143143
unpacked, _ = rnn_utils.pad_packed_sequence(packed)
144144
self.assertEqual(unpacked.type(), cuda_type_str)
145145

146+
def test_wrong_order(self):
147+
# https://github.com/pytorch/pytorch/issues/13324
148+
a = torch.ones(25, 300)
149+
b = torch.ones(22, 300)
150+
b_a = rnn_utils.pad_sequence([b, a])
151+
self.assertRaises(RuntimeError, lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25]))
152+
146153
def test_total_length(self):
147154
padded, lengths = self._padded_sequence(torch.FloatTensor)
148155
max_length = max(lengths)

0 commit comments

Comments
 (0)