Skip to content

Commit 8d3a82f

Browse files
committed
New API to check for half support on given device
Also fixes a doc test in Array struct implementation
1 parent 65cc4ff commit 8d3a82f

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/core/array.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,20 @@ where
189189
/// An example of creating an Array from half::f16 array
190190
///
191191
/// ```rust
192-
/// use arrayfire::{Array, Dim4, print};
192+
/// use arrayfire::{Array, Dim4, is_half_available, print};
193193
/// use half::f16;
194194
///
195195
/// let values: [f32; 3] = [1.0, 2.0, 3.0];
196196
///
197-
/// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
197+
/// if is_half_available(0) { // Default device is 0, hence the argument
198+
/// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
198199
///
199-
/// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1]));
200+
/// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1]));
200201
///
201-
/// print(&hvals);
202+
/// print(&hvals);
203+
/// } else {
204+
/// println!("Half support isn't available on this device");
205+
/// }
202206
/// ```
203207
///
204208
pub fn new(slice: &[T], dims: Dim4) -> Self {

src/core/device.rs

+19
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ extern "C" {
3636

3737
fn af_alloc_pinned(non_pagable_ptr: *mut void_ptr, bytes: dim_t) -> c_int;
3838
fn af_free_pinned(non_pagable_ptr: void_ptr) -> c_int;
39+
fn af_get_half_support(available: *mut c_int, device: c_int) -> c_int;
3940
}
4041

4142
/// Get ArrayFire Version Number
@@ -331,3 +332,21 @@ pub unsafe fn free_pinned(ptr: void_ptr) {
331332
let err_val = af_free_pinned(ptr);
332333
HANDLE_ERROR(AfError::from(err_val));
333334
}
335+
336+
/// Check if a device has half support
337+
///
338+
/// # Parameters
339+
///
340+
/// - `device` is the device for which half precision support is checked for
341+
///
342+
/// # Return Values
343+
///
344+
/// `True` if `device` device has half support, `False` otherwise.
345+
pub fn is_half_available(device: i32) -> bool {
346+
unsafe {
347+
let mut temp: i32 = 0;
348+
let err_val = af_get_half_support(&mut temp as *mut c_int, device as c_int);
349+
HANDLE_ERROR(AfError::from(err_val));
350+
temp > 0
351+
}
352+
}

0 commit comments

Comments
 (0)