-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathbackend.rs
67 lines (59 loc) · 1.84 KB
/
backend.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
use super::defines::{AfError, Backend};
use super::error::HANDLE_ERROR;
use libc::{c_int, c_uint};
// Raw bindings to ArrayFire's C backend-management API. Each function returns
// an ArrayFire status code, which the safe wrappers in this file convert with
// `AfError::from` and pass to `HANDLE_ERROR`.
extern "C" {
    // ArrayFire's C prototype is `af_set_backend(const af_backend bknd)`, where
    // `af_backend` is a C enum and therefore int-sized. Declaring the parameter
    // as `u8` relied on the caller widening the argument to 32 bits, which not
    // every target ABI guarantees, so it is declared as `c_int` here.
    fn af_set_backend(bknd: c_int) -> c_int;
    fn af_get_backend_count(num_backends: *mut c_uint) -> c_int;
    fn af_get_available_backends(backends: *mut c_int) -> c_int;
    fn af_get_active_backend(backend: *mut c_int) -> c_int;
}
/// Switch the active ArrayFire backend (e.g. CPU, CUDA, OpenCL or oneAPI).
///
/// # Parameters
///
/// - `backend`: the backend to switch to. Its discriminant is passed to
///   ArrayFire unchanged, as a C `int`.
///
/// The status code returned by ArrayFire is forwarded to `HANDLE_ERROR`.
pub fn set_backend(backend: Backend) {
    // SAFETY: the only argument is passed by value; no pointers are involved.
    let err_val = unsafe { af_set_backend(backend as c_int) };
    HANDLE_ERROR(AfError::from(err_val));
}
/// Return how many ArrayFire backends are available.
///
/// The count is written by `af_get_backend_count`. The status code from that
/// call is forwarded to `HANDLE_ERROR` before the count is returned.
pub fn get_backend_count() -> u32 {
    let mut count: c_uint = 0;
    // SAFETY: `count` is a valid, writable `c_uint` that outlives the call.
    let status = unsafe { af_get_backend_count(&mut count) };
    HANDLE_ERROR(AfError::from(status));
    count
}
/// List the backends ArrayFire reports as available.
///
/// `af_get_available_backends` fills an integer bit mask, where each set bit
/// maps to one [`Backend`]:
///
/// - `0b1000` → `ONEAPI`
/// - `0b0100` → `OPENCL`
/// - `0b0010` → `CUDA`
/// - `0b0001` → `CPU`
///
/// The returned vector lists the present backends in that order (highest bit
/// first). The status code from the call is forwarded to `HANDLE_ERROR`.
pub fn get_available_backends() -> Vec<Backend> {
    let mut mask: c_int = 0;
    // SAFETY: `mask` is a valid, writable `c_int` that outlives the call.
    let status = unsafe { af_get_available_backends(&mut mask) };
    HANDLE_ERROR(AfError::from(status));

    // Every mask tested below has exactly one bit set, so "any overlap" is
    // the same as "the bit is set".
    let has_bit = |bit: c_int| mask & bit != 0;

    let mut backends = Vec::new();
    if has_bit(0b1000) {
        backends.push(Backend::ONEAPI);
    }
    if has_bit(0b0100) {
        backends.push(Backend::OPENCL);
    }
    if has_bit(0b0010) {
        backends.push(Backend::CUDA);
    }
    if has_bit(0b0001) {
        backends.push(Backend::CPU);
    }
    backends
}
/// Return the backend ArrayFire currently has active.
///
/// ArrayFire reports the backend as an integer id, mapped as follows:
/// `0` → `DEFAULT`, `1` → `CPU`, `2` → `CUDA`, `4` → `OPENCL`, `8` → `ONEAPI`.
/// The status code from the call is forwarded to `HANDLE_ERROR` first.
///
/// # Panics
///
/// - If `af_get_active_backend` returns a non-zero status and `HANDLE_ERROR`
///   returns instead of panicking itself. The id is not trusted in that case.
/// - If the call succeeds but reports an id not listed above.
pub fn get_active_backend() -> Backend {
    let mut backend_id: c_int = 0;
    // SAFETY: `backend_id` is a valid, writable `c_int` that outlives the call.
    let err_val = unsafe { af_get_active_backend(&mut backend_id) };
    HANDLE_ERROR(AfError::from(err_val));
    // `err_val` stays in the match so a failed call is never interpreted as
    // `DEFAULT` just because `backend_id` kept its initial value of 0.
    match (err_val, backend_id) {
        (0, 0) => Backend::DEFAULT,
        (0, 1) => Backend::CPU,
        (0, 2) => Backend::CUDA,
        (0, 4) => Backend::OPENCL,
        (0, 8) => Backend::ONEAPI,
        (0, unknown) => panic!(
            "Invalid backend retrieved: unrecognised backend id {}",
            unknown
        ),
        (code, _) => panic!(
            "Invalid backend retrieved: af_get_active_backend returned error code {}",
            code
        ),
    }
}