Skip to content

Commit 99f6c4c

Browse files
committed
candle-core: add asort_big test demonstrating issues with sorting larger
tensors on cuda
1 parent 63437a4 commit 99f6c4c

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

candle-core/tests/tensor_tests.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,22 @@ fn asort(device: &Device) -> Result<()> {
219219
Ok(())
220220
}
221221

222+
/// Test sorting a large tensor that exceeds 1024 elements.
223+
fn asort_big(device: &Device) -> Result<()> {
224+
const SIZE: usize = 2000;
225+
let data: Vec<f32> = (0..SIZE).map(|x| (SIZE - x) as f32).collect();
226+
let tensor = Tensor::new(data.as_slice(), device)?;
227+
228+
let indexes = tensor.arg_sort_last_dim(true)?;
229+
let expected_indexes: Vec<u32> = (0..SIZE).rev().map(|x| x as u32).collect();
230+
assert_eq!(indexes.to_vec1::<u32>()?, expected_indexes);
231+
232+
let indexes = tensor.arg_sort_last_dim(false)?;
233+
let expected_indexes: Vec<u32> = (0..SIZE).map(|x| x as u32).collect();
234+
assert_eq!(indexes.to_vec1::<u32>()?, expected_indexes);
235+
Ok(())
236+
}
237+
222238
fn unary_op(device: &Device) -> Result<()> {
223239
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
224240
let tensor = Tensor::new(data, device)?;
@@ -1707,6 +1723,7 @@ test_device!(
17071723
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
17081724
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
17091725
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
1726+
test_device!(asort_big, asort_big_cpu, asort_big_gpu, asort_big_metal);
17101727
test_device!(var, var_cpu, var_gpu, var_metal);
17111728
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
17121729

0 commit comments

Comments
 (0)