@@ -5,7 +5,7 @@ use crate::array::Array;
5
5
use crate :: defines:: { AfError , BinaryOp } ;
6
6
use crate :: error:: HANDLE_ERROR ;
7
7
use crate :: util:: { AfArray , MutAfArray , MutDouble , MutUint } ;
8
- use crate :: util:: { HasAfEnum , RealNumber , Scanable } ;
8
+ use crate :: util:: { HasAfEnum , RealNumber , ReduceByKeyInput , Scanable } ;
9
9
10
10
#[ allow( dead_code) ]
11
11
extern "C" {
@@ -59,6 +59,71 @@ extern "C" {
59
59
op : c_uint ,
60
60
inclusive : c_int ,
61
61
) -> c_int ;
62
+ fn af_all_true_by_key (
63
+ keys_out : MutAfArray ,
64
+ vals_out : MutAfArray ,
65
+ keys : AfArray ,
66
+ vals : AfArray ,
67
+ dim : c_int ,
68
+ ) -> c_int ;
69
+ fn af_any_true_by_key (
70
+ keys_out : MutAfArray ,
71
+ vals_out : MutAfArray ,
72
+ keys : AfArray ,
73
+ vals : AfArray ,
74
+ dim : c_int ,
75
+ ) -> c_int ;
76
+ fn af_count_by_key (
77
+ keys_out : MutAfArray ,
78
+ vals_out : MutAfArray ,
79
+ keys : AfArray ,
80
+ vals : AfArray ,
81
+ dim : c_int ,
82
+ ) -> c_int ;
83
+ fn af_max_by_key (
84
+ keys_out : MutAfArray ,
85
+ vals_out : MutAfArray ,
86
+ keys : AfArray ,
87
+ vals : AfArray ,
88
+ dim : c_int ,
89
+ ) -> c_int ;
90
+ fn af_min_by_key (
91
+ keys_out : MutAfArray ,
92
+ vals_out : MutAfArray ,
93
+ keys : AfArray ,
94
+ vals : AfArray ,
95
+ dim : c_int ,
96
+ ) -> c_int ;
97
+ fn af_product_by_key (
98
+ keys_out : MutAfArray ,
99
+ vals_out : MutAfArray ,
100
+ keys : AfArray ,
101
+ vals : AfArray ,
102
+ dim : c_int ,
103
+ ) -> c_int ;
104
+ fn af_product_by_key_nan (
105
+ keys_out : MutAfArray ,
106
+ vals_out : MutAfArray ,
107
+ keys : AfArray ,
108
+ vals : AfArray ,
109
+ dim : c_int ,
110
+ nan_val : c_double ,
111
+ ) -> c_int ;
112
+ fn af_sum_by_key (
113
+ keys_out : MutAfArray ,
114
+ vals_out : MutAfArray ,
115
+ keys : AfArray ,
116
+ vals : AfArray ,
117
+ dim : c_int ,
118
+ ) -> c_int ;
119
+ fn af_sum_by_key_nan (
120
+ keys_out : MutAfArray ,
121
+ vals_out : MutAfArray ,
122
+ keys : AfArray ,
123
+ vals : AfArray ,
124
+ dim : c_int ,
125
+ nan_val : c_double ,
126
+ ) -> c_int ;
62
127
}
63
128
64
129
macro_rules! dim_reduce_func_def {
@@ -527,7 +592,8 @@ all_reduce_func_def!(
527
592
let dims = Dim4::new(&[5, 5, 1, 1]);
528
593
let a = randu::<f32>(dims);
529
594
print(&a);
530
- println!(\" Result : {:?}\" , product_all(&a));
595
+ let res = product_all(&a);
596
+ println!(\" Result : {:?}\" , res);
531
597
```
532
598
" ,
533
599
product_all,
@@ -1137,3 +1203,193 @@ where
1137
1203
}
1138
1204
temp. into ( )
1139
1205
}
1206
+
1207
+ macro_rules! dim_reduce_by_key_func_def {
1208
+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1209
+ #[ doc=$brief_str]
1210
+ /// # Parameters
1211
+ ///
1212
+ /// - `keys` - key Array
1213
+ /// - `vals` - value Array
1214
+ /// - `dim` - Dimension along which the input Array is reduced
1215
+ ///
1216
+ /// # Return Values
1217
+ ///
1218
+ /// Tuple of Arrays, with output keys and values after reduction
1219
+ ///
1220
+ #[ doc=$ex_str]
1221
+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1222
+ dim: i32
1223
+ ) -> ( Array <KeyType >, Array <$out_type>)
1224
+ where
1225
+ KeyType : ReduceByKeyInput ,
1226
+ ValueType : HasAfEnum ,
1227
+ $out_type: HasAfEnum ,
1228
+ {
1229
+ let mut out_keys: i64 = 0 ;
1230
+ let mut out_vals: i64 = 0 ;
1231
+ unsafe {
1232
+ let err_val = $ffi_name(
1233
+ & mut out_keys as MutAfArray ,
1234
+ & mut out_vals as MutAfArray ,
1235
+ keys. get( ) as AfArray ,
1236
+ vals. get( ) as AfArray ,
1237
+ dim as c_int,
1238
+ ) ;
1239
+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1240
+ }
1241
+ ( out_keys. into( ) , out_vals. into( ) )
1242
+ }
1243
+ } ;
1244
+ }
1245
+
1246
+ dim_reduce_by_key_func_def ! (
1247
+ "
1248
+ Key based AND of elements along a given dimension
1249
+
1250
+ All positive non-zero values are considered true, while negative and zero
1251
+ values are considered as false.
1252
+ " ,
1253
+ "
1254
+ # Examples
1255
+ ```rust
1256
+ use arrayfire::{Dim4, print, randu, all_true_by_key};
1257
+ let dims = Dim4::new(&[5, 3, 1, 1]);
1258
+ let vals = randu::<f32>(dims);
1259
+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1260
+ print(&vals);
1261
+ print(&keys);
1262
+ let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1263
+ print(&out_keys);
1264
+ print(&out_vals);
1265
+ ```
1266
+ " ,
1267
+ all_true_by_key,
1268
+ af_all_true_by_key,
1269
+ ValueType :: AggregateOutType
1270
+ ) ;
1271
+
1272
+ dim_reduce_by_key_func_def ! (
1273
+ "
1274
+ Key based OR of elements along a given dimension
1275
+
1276
+ All positive non-zero values are considered true, while negative and zero
1277
+ values are considered as false.
1278
+ " ,
1279
+ "
1280
+ # Examples
1281
+ ```rust
1282
+ use arrayfire::{Dim4, print, randu, any_true_by_key};
1283
+ let dims = Dim4::new(&[5, 3, 1, 1]);
1284
+ let vals = randu::<f32>(dims);
1285
+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1286
+ print(&vals);
1287
+ print(&keys);
1288
+ let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1289
+ print(&out_keys);
1290
+ print(&out_vals);
1291
+ ```
1292
+ " ,
1293
+ any_true_by_key,
1294
+ af_any_true_by_key,
1295
+ ValueType :: AggregateOutType
1296
+ ) ;
1297
+
1298
+ dim_reduce_by_key_func_def ! (
1299
+ "Find total count of elements with similar keys along a given dimension" ,
1300
+ "" ,
1301
+ count_by_key,
1302
+ af_count_by_key,
1303
+ ValueType :: AggregateOutType
1304
+ ) ;
1305
+
1306
+ dim_reduce_by_key_func_def ! (
1307
+ "Find maximum among values of similar keys along a given dimension" ,
1308
+ "" ,
1309
+ max_by_key,
1310
+ af_max_by_key,
1311
+ ValueType :: AggregateOutType
1312
+ ) ;
1313
+
1314
+ dim_reduce_by_key_func_def ! (
1315
+ "Find minimum among values of similar keys along a given dimension" ,
1316
+ "" ,
1317
+ min_by_key,
1318
+ af_min_by_key,
1319
+ ValueType :: AggregateOutType
1320
+ ) ;
1321
+
1322
+ dim_reduce_by_key_func_def ! (
1323
+ "Find product of all values with similar keys along a given dimension" ,
1324
+ "" ,
1325
+ product_by_key,
1326
+ af_product_by_key,
1327
+ ValueType :: ProductOutType
1328
+ ) ;
1329
+
1330
+ dim_reduce_by_key_func_def ! (
1331
+ "Find sum of all values with similar keys along a given dimension" ,
1332
+ "" ,
1333
+ sum_by_key,
1334
+ af_sum_by_key,
1335
+ ValueType :: AggregateOutType
1336
+ ) ;
1337
+
1338
+ macro_rules! dim_reduce_by_key_nan_func_def {
1339
+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1340
+ #[ doc=$brief_str]
1341
+ ///
1342
+ /// This version of sum by key can replaced all NaN values in the input
1343
+ /// with a user provided value before performing the reduction operation.
1344
+ /// # Parameters
1345
+ ///
1346
+ /// - `keys` - key Array
1347
+ /// - `vals` - value Array
1348
+ /// - `dim` - Dimension along which the input Array is reduced
1349
+ ///
1350
+ /// # Return Values
1351
+ ///
1352
+ /// Tuple of Arrays, with output keys and values after reduction
1353
+ ///
1354
+ #[ doc=$ex_str]
1355
+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1356
+ dim: i32 , replace_value: f64
1357
+ ) -> ( Array <KeyType >, Array <$out_type>)
1358
+ where
1359
+ KeyType : ReduceByKeyInput ,
1360
+ ValueType : HasAfEnum ,
1361
+ $out_type: HasAfEnum ,
1362
+ {
1363
+ let mut out_keys: i64 = 0 ;
1364
+ let mut out_vals: i64 = 0 ;
1365
+ unsafe {
1366
+ let err_val = $ffi_name(
1367
+ & mut out_keys as MutAfArray ,
1368
+ & mut out_vals as MutAfArray ,
1369
+ keys. get( ) as AfArray ,
1370
+ vals. get( ) as AfArray ,
1371
+ dim as c_int,
1372
+ replace_value as c_double,
1373
+ ) ;
1374
+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1375
+ }
1376
+ ( out_keys. into( ) , out_vals. into( ) )
1377
+ }
1378
+ } ;
1379
+ }
1380
+
1381
+ dim_reduce_by_key_nan_func_def ! (
1382
+ "Compute sum of all values with similar keys along a given dimension" ,
1383
+ "" ,
1384
+ sum_by_key_nan,
1385
+ af_sum_by_key_nan,
1386
+ ValueType :: AggregateOutType
1387
+ ) ;
1388
+
1389
+ dim_reduce_by_key_nan_func_def ! (
1390
+ "Compute product of all values with similar keys along a given dimension" ,
1391
+ "" ,
1392
+ product_by_key_nan,
1393
+ af_product_by_key_nan,
1394
+ ValueType :: ProductOutType
1395
+ ) ;
0 commit comments