Skip to content

Commit 62de25b

Browse files
morenolMec-iS
andcommitted
Handle kernel serialization (#232)
* Handle kernel serialization * Do not use typetag in WASM * enable tests for serialization * Update serde feature deps Co-authored-by: Luis Moreno <[email protected]> Co-authored-by: Lorenzo <[email protected]>
1 parent 7d87451 commit 62de25b

File tree

4 files changed

+30
-50
lines changed

4 files changed

+30
-50
lines changed

Cargo.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
2929
rand_distr = { version = "0.4", optional = true }
3030
serde = { version = "1", features = ["derive"], optional = true }
3131

32+
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
33+
typetag = { version = "0.2", optional = true }
34+
3235
[features]
3336
default = []
34-
serde = ["dep:serde"]
37+
serde = ["dep:serde", "dep:typetag"]
3538
ndarray-bindings = ["dep:ndarray"]
3639
datasets = ["dep:rand_distr", "std_rand", "serde"]
3740
std_rand = ["rand/std_rng", "rand/std"]

src/svm/mod.rs

+10-36
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ pub mod svr;
3030

3131
use core::fmt::Debug;
3232

33-
#[cfg(feature = "serde")]
34-
use serde::ser::{SerializeStruct, Serializer};
3533
#[cfg(feature = "serde")]
3634
use serde::{Deserialize, Serialize};
3735

@@ -40,36 +38,20 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};
4038

4139
/// Defines a kernel function.
4240
/// This is a object-safe trait.
43-
pub trait Kernel {
41+
#[cfg_attr(
42+
all(feature = "serde", not(target_arch = "wasm32")),
43+
typetag::serde(tag = "type")
44+
)]
45+
pub trait Kernel: Debug {
4446
#[allow(clippy::ptr_arg)]
4547
/// Apply kernel function to x_i and x_j
4648
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
47-
/// Return a serializable name
48-
fn name(&self) -> &'static str;
49-
}
50-
51-
impl Debug for dyn Kernel {
52-
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53-
write!(f, "Kernel<f64>")
54-
}
55-
}
56-
57-
#[cfg(feature = "serde")]
58-
impl Serialize for dyn Kernel {
59-
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60-
where
61-
S: Serializer,
62-
{
63-
let mut s = serializer.serialize_struct("Kernel", 1)?;
64-
s.serialize_field("type", &self.name())?;
65-
s.end()
66-
}
6749
}
6850

6951
/// Pre-defined kernel functions
7052
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7153
#[derive(Debug, Clone)]
72-
pub struct Kernels {}
54+
pub struct Kernels;
7355

7456
impl Kernels {
7557
/// Return a default linear
@@ -211,15 +193,14 @@ impl SigmoidKernel {
211193
}
212194
}
213195

196+
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
214197
impl Kernel for LinearKernel {
215198
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
216199
Ok(x_i.dot(x_j))
217200
}
218-
fn name(&self) -> &'static str {
219-
"Linear"
220-
}
221201
}
222202

203+
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
223204
impl Kernel for RBFKernel {
224205
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
225206
if self.gamma.is_none() {
@@ -231,11 +212,9 @@ impl Kernel for RBFKernel {
231212
let v_diff = x_i.sub(x_j);
232213
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
233214
}
234-
fn name(&self) -> &'static str {
235-
"RBF"
236-
}
237215
}
238216

217+
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
239218
impl Kernel for PolynomialKernel {
240219
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
241220
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
@@ -247,11 +226,9 @@ impl Kernel for PolynomialKernel {
247226
let dot = x_i.dot(x_j);
248227
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
249228
}
250-
fn name(&self) -> &'static str {
251-
"Polynomial"
252-
}
253229
}
254230

231+
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
255232
impl Kernel for SigmoidKernel {
256233
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
257234
if self.gamma.is_none() || self.coef0.is_none() {
@@ -263,9 +240,6 @@ impl Kernel for SigmoidKernel {
263240
let dot = x_i.dot(x_j);
264241
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
265242
}
266-
fn name(&self) -> &'static str {
267-
"Sigmoid"
268-
}
269243
}
270244

271245
#[cfg(test)]

src/svm/svc.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,11 @@ pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX
100100
pub c: TX,
101101
/// Tolerance for stopping criterion.
102102
pub tol: TX,
103-
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
104103
/// The kernel function.
104+
#[cfg_attr(
105+
all(feature = "serde", target_arch = "wasm32"),
106+
serde(skip_serializing, skip_deserializing)
107+
)]
105108
pub kernel: Option<Box<dyn Kernel>>,
106109
/// Unused parameter.
107110
m: PhantomData<(X, Y, TY)>,
@@ -1085,7 +1088,7 @@ mod tests {
10851088
wasm_bindgen_test::wasm_bindgen_test
10861089
)]
10871090
#[test]
1088-
#[cfg(feature = "serde")]
1091+
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
10891092
fn svc_serde() {
10901093
let x = DenseMatrix::from_2d_array(&[
10911094
&[5.1, 3.5, 1.4, 0.2],
@@ -1119,8 +1122,9 @@ mod tests {
11191122
let svc = SVC::fit(&x, &y, &params).unwrap();
11201123

11211124
// serialization
1122-
let serialized_svc = &serde_json::to_string(&svc).unwrap();
1125+
let deserialized_svc: SVC<f64, i32, _, _> =
1126+
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
11231127

1124-
println!("{:?}", serialized_svc);
1128+
assert_eq!(svc, deserialized_svc);
11251129
}
11261130
}

src/svm/svr.rs

+8-9
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,11 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
9292
pub c: T,
9393
/// Tolerance for stopping criterion.
9494
pub tol: T,
95-
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
9695
/// The kernel function.
96+
#[cfg_attr(
97+
all(feature = "serde", target_arch = "wasm32"),
98+
serde(skip_serializing, skip_deserializing)
99+
)]
97100
pub kernel: Option<Box<dyn Kernel>>,
98101
}
99102

@@ -668,7 +671,7 @@ mod tests {
668671
wasm_bindgen_test::wasm_bindgen_test
669672
)]
670673
#[test]
671-
#[cfg(feature = "serde")]
674+
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
672675
fn svr_serde() {
673676
let x = DenseMatrix::from_2d_array(&[
674677
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
@@ -699,13 +702,9 @@ mod tests {
699702

700703
let svr = SVR::fit(&x, &y, &params).unwrap();
701704

702-
let serialized = &serde_json::to_string(&svr).unwrap();
703-
704-
println!("{}", &serialized);
705-
706-
// let deserialized_svr: SVR<f64, DenseMatrix<f64>, LinearKernel> =
707-
// serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
705+
let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
706+
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
708707

709-
// assert_eq!(svr, deserialized_svr);
708+
assert_eq!(svr, deserialized_svr);
710709
}
711710
}

0 commit comments

Comments
 (0)