Skip to content

Commit 8db332e

Browse files
committed
Change Callback struct to hold fn() instead of &Fn()
This removed the need to impl unsafe Send & Sync traits for Callback type. Also updated the error handler test to be more rusty.
1 parent e99ea54 commit 8db332e

File tree

2 files changed

+48
-56
lines changed

2 files changed

+48
-56
lines changed

src/error.rs

+31-32
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,39 @@
11
use std::ops::{Deref, DerefMut};
22
use defines::AfError;
33
use std::error::Error;
4-
use std::marker::{Send, Sync};
54
use std::sync::RwLock;
65

7-
/// Signature of callback function to be called to handle errors
8-
pub type ErrorCallback = Fn(AfError);
6+
/// Signature of error handling callback function
7+
pub type ErrorCallback = fn(AfError);
98

10-
/// Wrap ErrorCallback function pointer inside a structure
11-
/// to enable implementing Send, Sync traits on it.
12-
pub struct Callback<'cblifetime> {
13-
///Reference to a valid error callback function
14-
///Make sure this callback stays relevant throughout the lifetime of application.
15-
pub cb: &'cblifetime ErrorCallback,
9+
/// Structure holding handle to callback function
10+
pub struct Callback {
11+
cb: ErrorCallback,
1612
}
1713

18-
// Implement Send, Sync traits for Callback structure to
19-
// enable the user of Callback function pointer in conjunction
20-
// with threads using a mutex.
21-
unsafe impl<'cblifetime> Send for Callback<'cblifetime> {}
22-
unsafe impl<'cblifetime> Sync for Callback<'cblifetime> {}
14+
impl Callback {
15+
/// Associated function to create a new Callback object
16+
pub fn new(callback: ErrorCallback) -> Self {
17+
Callback {cb: callback}
18+
}
19+
20+
/// call invokes the error callback with `error_code`.
21+
pub fn call(&self, error_code: AfError) {
22+
(self.cb)(error_code)
23+
}
24+
}
2325

24-
pub const DEFAULT_HANDLE_ERROR: Callback<'static> = Callback{cb: &handle_error_general};
26+
/// Default error handling callback provided by ArrayFire crate
27+
pub fn handle_error_general(error_code: AfError) {
28+
match error_code {
29+
AfError::SUCCESS => {}, /* No-op */
30+
_ => panic!("Error message: {}", error_code.description()),
31+
}
32+
}
2533

2634
lazy_static! {
27-
static ref ERROR_HANDLER_LOCK: RwLock< Callback<'static> > =
28-
RwLock::new(DEFAULT_HANDLE_ERROR);
35+
static ref ERROR_HANDLER_LOCK: RwLock< Callback > =
36+
RwLock::new(Callback::new(handle_error_general));
2937
}
3038

3139
/// Register user provided error handler
@@ -45,16 +53,17 @@ lazy_static! {
4553
/// }
4654
/// }
4755
///
48-
/// pub const ERR_HANDLE: Callback<'static> = Callback{ cb: &handleError};
49-
///
5056
/// fn main() {
51-
/// register_error_handler(ERR_HANDLE);
57+
/// //Registering the error handler should be the first call
58+
/// //before any other functions are called if your version
59+
/// //of error is to be used for subsequent function calls
60+
/// register_error_handler(Callback::new(handleError));
5261
///
5362
/// info();
5463
/// }
5564
/// ```
5665
#[allow(unused_must_use)]
57-
pub fn register_error_handler(cb_value: Callback<'static>) {
66+
pub fn register_error_handler(cb_value: Callback) {
5867
let mut gaurd = match ERROR_HANDLER_LOCK.write() {
5968
Ok(g) => g,
6069
Err(_)=> panic!("Failed to acquire lock to register error handler"),
@@ -63,22 +72,12 @@ pub fn register_error_handler(cb_value: Callback<'static>) {
6372
*gaurd.deref_mut() = cb_value;
6473
}
6574

66-
/// Default error handling callback provided by ArrayFire crate
67-
pub fn handle_error_general(error_code: AfError) {
68-
match error_code {
69-
AfError::SUCCESS => {}, /* No-op */
70-
_ => panic!("Error message: {}", error_code.description()),
71-
}
72-
}
73-
7475
#[allow(non_snake_case)]
7576
pub fn HANDLE_ERROR(error_code: AfError) {
7677
let gaurd = match ERROR_HANDLER_LOCK.read() {
7778
Ok(g) => g,
7879
Err(_)=> panic!("Failed to acquire lock while handling FFI return value"),
7980
};
8081

81-
let func = gaurd.deref().cb;
82-
83-
func(error_code);
82+
(*gaurd.deref()).call(error_code);
8483
}

tests/lib.rs

+17-24
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,44 @@ extern crate arrayfire as af;
22

33
use std::error::Error;
44
use std::thread;
5-
use std::time::Duration;
65
use af::*;
76

87
macro_rules! implement_handler {
9-
($fn_name:ident, $msg: expr) => (
10-
8+
($fn_name:ident) => (
119
pub fn $fn_name(error_code: AfError) {
12-
println!("{:?}", $msg);
1310
match error_code {
1411
AfError::SUCCESS => {}, /* No-op */
1512
_ => panic!("Error message: {}", error_code.description()),
1613
}
1714
}
18-
1915
)
2016
}
2117

22-
implement_handler!(handler_sample1, "Error Handler Sample1");
23-
implement_handler!(handler_sample2, "Error Handler Sample2");
24-
implement_handler!(handler_sample3, "Error Handler Sample3");
25-
implement_handler!(handler_sample4, "Error Handler Sample4");
26-
27-
pub const HANDLE1: Callback<'static> = Callback{ cb: &handler_sample1};
28-
pub const HANDLE2: Callback<'static> = Callback{ cb: &handler_sample2};
29-
pub const HANDLE3: Callback<'static> = Callback{ cb: &handler_sample3};
30-
pub const HANDLE4: Callback<'static> = Callback{ cb: &handler_sample4};
18+
implement_handler!(handler_sample1);
19+
implement_handler!(handler_sample2);
20+
implement_handler!(handler_sample3);
21+
implement_handler!(handler_sample4);
3122

3223
#[allow(unused_must_use)]
3324
#[test]
3425
fn check_error_handler_mutation() {
3526

36-
for i in 0..4 {
27+
let children = (0..4).map(|i| {
3728
thread::Builder::new().name(format!("child {}",i+1).to_string()).spawn(move || {
38-
println!("{:?}", thread::current());
29+
let target_device = i%af::device_count();
30+
println!("Thread {:?} 's target device is {}", thread::current(), target_device);
3931
match i {
40-
0 => register_error_handler(HANDLE1),
41-
1 => register_error_handler(HANDLE2),
42-
2 => register_error_handler(HANDLE3),
43-
3 => register_error_handler(HANDLE4),
32+
0 => register_error_handler(Callback::new(handler_sample1)),
33+
1 => register_error_handler(Callback::new(handler_sample2)),
34+
2 => register_error_handler(Callback::new(handler_sample3)),
35+
3 => register_error_handler(Callback::new(handler_sample4)),
4436
_ => panic!("Impossible scenario"),
4537
}
46-
});
47-
}
38+
}).ok().expect("Failed to launch a thread")
39+
}).collect::< Vec<_> >();
4840

49-
af::info();
50-
thread::sleep(Duration::from_millis(50));
41+
for c in children {
42+
c.join();
43+
}
5144

5245
}

0 commit comments

Comments
 (0)