@@ -219,6 +219,22 @@ fn asort(device: &Device) -> Result<()> {
219219 Ok ( ( ) )
220220}
221221
222+ /// Test sorting a large tensor that exceeds 1024 elements.
223+ fn asort_big ( device : & Device ) -> Result < ( ) > {
224+ const SIZE : usize = 2000 ;
225+ let data: Vec < f32 > = ( 0 ..SIZE ) . map ( |x| ( SIZE - x) as f32 ) . collect ( ) ;
226+ let tensor = Tensor :: new ( data. as_slice ( ) , device) ?;
227+
228+ let indexes = tensor. arg_sort_last_dim ( true ) ?;
229+ let expected_indexes: Vec < u32 > = ( 0 ..SIZE ) . rev ( ) . map ( |x| x as u32 ) . collect ( ) ;
230+ assert_eq ! ( indexes. to_vec1:: <u32 >( ) ?, expected_indexes) ;
231+
232+ let indexes = tensor. arg_sort_last_dim ( false ) ?;
233+ let expected_indexes: Vec < u32 > = ( 0 ..SIZE ) . map ( |x| x as u32 ) . collect ( ) ;
234+ assert_eq ! ( indexes. to_vec1:: <u32 >( ) ?, expected_indexes) ;
235+ Ok ( ( ) )
236+ }
237+
222238fn unary_op ( device : & Device ) -> Result < ( ) > {
223239 let data = & [ [ -3f32 , 1. , 4. , -0.1 , 0.5 ] , [ 2.7 , -1.8 , -0.28 , 1.8 , 2.8 ] ] ;
224240 let tensor = Tensor :: new ( data, device) ?;
@@ -1707,6 +1723,7 @@ test_device!(
17071723test_device ! ( randn, randn_cpu, randn_gpu, randn_metal) ;
17081724test_device ! ( clamp, clamp_cpu, clamp_gpu, clamp_metal) ;
17091725test_device ! ( asort, asort_cpu, asort_gpu, asort_metal) ;
1726+ test_device ! ( asort_big, asort_big_cpu, asort_big_gpu, asort_big_metal) ;
17101727test_device ! ( var, var_cpu, var_gpu, var_metal) ;
17111728test_device ! ( zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal) ;
17121729
0 commit comments