Skip to content

Commit 32bc2f8

Browse files
committed
Update API to reflect ArrayFire 3.7.0 release
1 parent 0557ab4 commit 32bc2f8

File tree

19 files changed

+1468
-14
lines changed

19 files changed

+1468
-14
lines changed

Cargo.toml

+8-1
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,25 @@ indexing = []
2727
graphics = []
2828
image = []
2929
lapack = []
30+
machine_learning = []
3031
macros = []
3132
random = []
3233
signal = []
3334
sparse = []
3435
statistics = []
3536
vision = []
3637
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
37-
"macros", "random", "signal", "sparse", "statistics", "vision"]
38+
"machine_learning", "macros", "random", "signal", "sparse", "statistics", "vision"]
3839

3940
[dependencies]
4041
libc = "0.2"
4142
num = "0.2"
4243
lazy_static = "1.0"
44+
half = "1.5.0"
4345

4446
[dev-dependencies]
4547
float-cmp = "0.6.0"
48+
half = "1.5.0"
4649

4750
[build-dependencies]
4851
serde_json = "1.0"
@@ -85,3 +88,7 @@ path = "examples/conway.rs"
8588
[[example]]
8689
name = "fft"
8790
path = "examples/fft.rs"
91+
92+
[[example]]
93+
name = "using_half"
94+
path = "examples/using_half.rs"

examples/conway.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ fn main() {
77
}
88

99
fn normalise(a: &Array<f32>) -> Array<f32> {
10-
(a / (max_all(&abs(a)).0 as f32))
10+
a / (max_all(&abs(a)).0 as f32)
1111
}
1212

1313
fn conways_game_of_life() {

examples/using_half.rs

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use arrayfire::*;
2+
use half::f16;
3+
4+
fn main() {
5+
set_device(0);
6+
info();
7+
8+
let values: Vec<_> = (1u8..101).map(f32::from).collect();
9+
10+
let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
11+
12+
let hvals = Array::new(&half_values, Dim4::new(&[10, 10, 1, 1]));
13+
14+
print(&hvals);
15+
}

src/algorithm/mod.rs

+258-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::array::Array;
55
use crate::defines::{AfError, BinaryOp};
66
use crate::error::HANDLE_ERROR;
77
use crate::util::{AfArray, MutAfArray, MutDouble, MutUint};
8-
use crate::util::{HasAfEnum, RealNumber, Scanable};
8+
use crate::util::{HasAfEnum, RealNumber, ReduceByKeyInput, Scanable};
99

1010
#[allow(dead_code)]
1111
extern "C" {
@@ -59,6 +59,71 @@ extern "C" {
5959
op: c_uint,
6060
inclusive: c_int,
6161
) -> c_int;
62+
fn af_all_true_by_key(
63+
keys_out: MutAfArray,
64+
vals_out: MutAfArray,
65+
keys: AfArray,
66+
vals: AfArray,
67+
dim: c_int,
68+
) -> c_int;
69+
fn af_any_true_by_key(
70+
keys_out: MutAfArray,
71+
vals_out: MutAfArray,
72+
keys: AfArray,
73+
vals: AfArray,
74+
dim: c_int,
75+
) -> c_int;
76+
fn af_count_by_key(
77+
keys_out: MutAfArray,
78+
vals_out: MutAfArray,
79+
keys: AfArray,
80+
vals: AfArray,
81+
dim: c_int,
82+
) -> c_int;
83+
fn af_max_by_key(
84+
keys_out: MutAfArray,
85+
vals_out: MutAfArray,
86+
keys: AfArray,
87+
vals: AfArray,
88+
dim: c_int,
89+
) -> c_int;
90+
fn af_min_by_key(
91+
keys_out: MutAfArray,
92+
vals_out: MutAfArray,
93+
keys: AfArray,
94+
vals: AfArray,
95+
dim: c_int,
96+
) -> c_int;
97+
fn af_product_by_key(
98+
keys_out: MutAfArray,
99+
vals_out: MutAfArray,
100+
keys: AfArray,
101+
vals: AfArray,
102+
dim: c_int,
103+
) -> c_int;
104+
fn af_product_by_key_nan(
105+
keys_out: MutAfArray,
106+
vals_out: MutAfArray,
107+
keys: AfArray,
108+
vals: AfArray,
109+
dim: c_int,
110+
nan_val: c_double,
111+
) -> c_int;
112+
fn af_sum_by_key(
113+
keys_out: MutAfArray,
114+
vals_out: MutAfArray,
115+
keys: AfArray,
116+
vals: AfArray,
117+
dim: c_int,
118+
) -> c_int;
119+
fn af_sum_by_key_nan(
120+
keys_out: MutAfArray,
121+
vals_out: MutAfArray,
122+
keys: AfArray,
123+
vals: AfArray,
124+
dim: c_int,
125+
nan_val: c_double,
126+
) -> c_int;
62127
}
63128

64129
macro_rules! dim_reduce_func_def {
@@ -527,7 +592,8 @@ all_reduce_func_def!(
527592
let dims = Dim4::new(&[5, 5, 1, 1]);
528593
let a = randu::<f32>(dims);
529594
print(&a);
530-
println!(\"Result : {:?}\", product_all(&a));
595+
let res = product_all(&a);
596+
println!(\"Result : {:?}\", res);
531597
```
532598
",
533599
product_all,
@@ -1137,3 +1203,193 @@ where
11371203
}
11381204
temp.into()
11391205
}
1206+
1207+
macro_rules! dim_reduce_by_key_func_def {
1208+
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1209+
#[doc=$brief_str]
1210+
/// # Parameters
1211+
///
1212+
/// - `keys` - key Array
1213+
/// - `vals` - value Array
1214+
/// - `dim` - Dimension along which the input Array is reduced
1215+
///
1216+
/// # Return Values
1217+
///
1218+
/// Tuple of Arrays, with output keys and values after reduction
1219+
///
1220+
#[doc=$ex_str]
1221+
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
1222+
dim: i32
1223+
) -> (Array<KeyType>, Array<$out_type>)
1224+
where
1225+
KeyType: ReduceByKeyInput,
1226+
ValueType: HasAfEnum,
1227+
$out_type: HasAfEnum,
1228+
{
1229+
let mut out_keys: i64 = 0;
1230+
let mut out_vals: i64 = 0;
1231+
unsafe {
1232+
let err_val = $ffi_name(
1233+
&mut out_keys as MutAfArray,
1234+
&mut out_vals as MutAfArray,
1235+
keys.get() as AfArray,
1236+
vals.get() as AfArray,
1237+
dim as c_int,
1238+
);
1239+
HANDLE_ERROR(AfError::from(err_val));
1240+
}
1241+
(out_keys.into(), out_vals.into())
1242+
}
1243+
};
1244+
}
1245+
1246+
dim_reduce_by_key_func_def!(
1247+
"
1248+
Key based AND of elements along a given dimension
1249+
1250+
All positive non-zero values are considered true, while negative and zero
1251+
values are considered as false.
1252+
",
1253+
"
1254+
# Examples
1255+
```rust
1256+
use arrayfire::{Dim4, print, randu, all_true_by_key};
1257+
let dims = Dim4::new(&[5, 3, 1, 1]);
1258+
let vals = randu::<f32>(dims);
1259+
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1260+
print(&vals);
1261+
print(&keys);
1262+
let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1263+
print(&out_keys);
1264+
print(&out_vals);
1265+
```
1266+
",
1267+
all_true_by_key,
1268+
af_all_true_by_key,
1269+
ValueType::AggregateOutType
1270+
);
1271+
1272+
dim_reduce_by_key_func_def!(
1273+
"
1274+
Key based OR of elements along a given dimension
1275+
1276+
All positive non-zero values are considered true, while negative and zero
1277+
values are considered as false.
1278+
",
1279+
"
1280+
# Examples
1281+
```rust
1282+
use arrayfire::{Dim4, print, randu, any_true_by_key};
1283+
let dims = Dim4::new(&[5, 3, 1, 1]);
1284+
let vals = randu::<f32>(dims);
1285+
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1286+
print(&vals);
1287+
print(&keys);
1288+
let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1289+
print(&out_keys);
1290+
print(&out_vals);
1291+
```
1292+
",
1293+
any_true_by_key,
1294+
af_any_true_by_key,
1295+
ValueType::AggregateOutType
1296+
);
1297+
1298+
dim_reduce_by_key_func_def!(
1299+
"Find total count of elements with similar keys along a given dimension",
1300+
"",
1301+
count_by_key,
1302+
af_count_by_key,
1303+
ValueType::AggregateOutType
1304+
);
1305+
1306+
dim_reduce_by_key_func_def!(
1307+
"Find maximum among values of similar keys along a given dimension",
1308+
"",
1309+
max_by_key,
1310+
af_max_by_key,
1311+
ValueType::AggregateOutType
1312+
);
1313+
1314+
dim_reduce_by_key_func_def!(
1315+
"Find minimum among values of similar keys along a given dimension",
1316+
"",
1317+
min_by_key,
1318+
af_min_by_key,
1319+
ValueType::AggregateOutType
1320+
);
1321+
1322+
dim_reduce_by_key_func_def!(
1323+
"Find product of all values with similar keys along a given dimension",
1324+
"",
1325+
product_by_key,
1326+
af_product_by_key,
1327+
ValueType::ProductOutType
1328+
);
1329+
1330+
dim_reduce_by_key_func_def!(
1331+
"Find sum of all values with similar keys along a given dimension",
1332+
"",
1333+
sum_by_key,
1334+
af_sum_by_key,
1335+
ValueType::AggregateOutType
1336+
);
1337+
1338+
macro_rules! dim_reduce_by_key_nan_func_def {
1339+
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1340+
#[doc=$brief_str]
1341+
///
1342+
/// This version of sum by key can replaced all NaN values in the input
1343+
/// with a user provided value before performing the reduction operation.
1344+
/// # Parameters
1345+
///
1346+
/// - `keys` - key Array
1347+
/// - `vals` - value Array
1348+
/// - `dim` - Dimension along which the input Array is reduced
1349+
///
1350+
/// # Return Values
1351+
///
1352+
/// Tuple of Arrays, with output keys and values after reduction
1353+
///
1354+
#[doc=$ex_str]
1355+
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
1356+
dim: i32, replace_value: f64
1357+
) -> (Array<KeyType>, Array<$out_type>)
1358+
where
1359+
KeyType: ReduceByKeyInput,
1360+
ValueType: HasAfEnum,
1361+
$out_type: HasAfEnum,
1362+
{
1363+
let mut out_keys: i64 = 0;
1364+
let mut out_vals: i64 = 0;
1365+
unsafe {
1366+
let err_val = $ffi_name(
1367+
&mut out_keys as MutAfArray,
1368+
&mut out_vals as MutAfArray,
1369+
keys.get() as AfArray,
1370+
vals.get() as AfArray,
1371+
dim as c_int,
1372+
replace_value as c_double,
1373+
);
1374+
HANDLE_ERROR(AfError::from(err_val));
1375+
}
1376+
(out_keys.into(), out_vals.into())
1377+
}
1378+
};
1379+
}
1380+
1381+
dim_reduce_by_key_nan_func_def!(
1382+
"Compute sum of all values with similar keys along a given dimension",
1383+
"",
1384+
sum_by_key_nan,
1385+
af_sum_by_key_nan,
1386+
ValueType::AggregateOutType
1387+
);
1388+
1389+
dim_reduce_by_key_nan_func_def!(
1390+
"Compute product of all values with similar keys along a given dimension",
1391+
"",
1392+
product_by_key_nan,
1393+
af_product_by_key_nan,
1394+
ValueType::ProductOutType
1395+
);

src/arith/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ extern "C" {
8585
fn af_log10(out: MutAfArray, arr: AfArray) -> c_int;
8686
fn af_log2(out: MutAfArray, arr: AfArray) -> c_int;
8787
fn af_sqrt(out: MutAfArray, arr: AfArray) -> c_int;
88+
fn af_rsqrt(out: MutAfArray, arr: AfArray) -> c_int;
8889
fn af_cbrt(out: MutAfArray, arr: AfArray) -> c_int;
8990
fn af_factorial(out: MutAfArray, arr: AfArray) -> c_int;
9091
fn af_tgamma(out: MutAfArray, arr: AfArray) -> c_int;
@@ -199,6 +200,12 @@ unary_func!("Compute the natural logarithm", log, af_log, UnaryOutType);
199200
unary_func!("Compute sin", sin, af_sin, UnaryOutType);
200201
unary_func!("Compute sinh", sinh, af_sinh, UnaryOutType);
201202
unary_func!("Compute the square root", sqrt, af_sqrt, UnaryOutType);
203+
unary_func!(
204+
"Compute the reciprocal square root",
205+
rsqrt,
206+
af_rsqrt,
207+
UnaryOutType
208+
);
202209
unary_func!("Compute tan", tan, af_tan, UnaryOutType);
203210
unary_func!("Compute tanh", tanh, af_tanh, UnaryOutType);
204211

0 commit comments

Comments
 (0)