Skip to content

Commit

Permalink
Merge pull request #32 from maciejkula/fix/serialization
Browse files Browse the repository at this point in the history
Fix EncodableRng JSON serialization.
  • Loading branch information
maciejkula authored Aug 27, 2016
2 parents b2d4201 + e1054c8 commit 75512db
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rustlearn"
version = "0.4.1"
version = "0.4.2"
description = "A machine learning package for Rust."
documentation = "https://maciejkula.github.io/rustlearn/doc/rustlearn/"
homepage = "https://github.com/maciejkula/rustlearn"
Expand All @@ -19,14 +19,14 @@ build = "build.rs"

[dependencies]
rand = "0.3"
rustc-serialize = "0.3"
rustc-serialize = "0.3.16"
crossbeam = "0.2.9"

[build-dependencies]
gcc = "0.3"

[dev-dependencies]
bincode = "0.4.1"
bincode = "0.6.0"
csv = "0.14"
hyper = "0.7.0"
time = "0.1"
Expand Down
6 changes: 5 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Changelog

## [0.4.1][unreleased]
## [0.4.2][unreleased]
### Fixed
- fixed EncodableRng dummy serialization implementation

## [0.4.1][2016-08-26]
### Fixed
- panic when removing constant features when splitting
a decision tree
Expand Down
12 changes: 12 additions & 0 deletions src/trees/decision_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,8 @@ mod tests {

use rand::{StdRng, SeedableRng};

use rustc_serialize::json;

use bincode;

#[cfg(feature = "all_tests")]
Expand Down Expand Up @@ -1174,6 +1176,7 @@ mod tests {

model.fit(&x_train, &y_train).unwrap();

// Binary encoding
let encoded = bincode::rustc_serialize::encode(&model, bincode::SizeLimit::Infinite)
.unwrap();
let decoded: OneVsRestWrapper<DecisionTree> =
Expand All @@ -1182,6 +1185,15 @@ mod tests {
let test_prediction = decoded.predict(&x_test).unwrap();

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &test_prediction);

// JSON encoding
let encoded = json::encode(&model).unwrap();
let decoded: OneVsRestWrapper<DecisionTree> =
json::decode(&encoded).unwrap();

let test_prediction = decoded.predict(&x_test).unwrap();

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &test_prediction);
}

test_accuracy /= no_splits as f32;
Expand Down
23 changes: 21 additions & 2 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use prelude::*;
/// Wrapper for making random number generators serializable.
/// Does no actual encoding, and merely creates a new
/// generator on decoding.
/// This is because rand generators do not expose internal state.
#[derive(Clone)]
pub struct EncodableRng {
pub rng: StdRng,
Expand All @@ -33,14 +34,16 @@ impl Default for EncodableRng {


impl Encodable for EncodableRng {
fn encode<S: Encoder>(&self, _: &mut S) -> Result<(), S::Error> {
fn encode<S: Encoder>(&self, s: &mut S) -> Result<(), S::Error> {
try!(s.emit_struct("EncodableRng", 0, |_| { Ok(()) }));
Ok(())
}
}


impl Decodable for EncodableRng {
fn decode<D: Decoder>(_: &mut D) -> Result<Self, D::Error> {
fn decode<D: Decoder>(d: &mut D) -> Result<Self, D::Error> {
try!(d.read_struct("", 0, |_| { Ok(()) }));
Ok((EncodableRng::new()))
}
}
Expand Down Expand Up @@ -81,3 +84,19 @@ pub fn check_matched_dimensions<T: IndexableMatrix>(X: &T, y: &Array) -> Result<
Err("Data matrix and target array do not have the same number of rows")
}
}


#[cfg(test)]
mod tests {
use super::EncodableRng;

use rustc_serialize::json;

#[test]
fn test_encodable_rng_serialization() {
let rng = EncodableRng::new();

let serialized = json::encode(&rng).unwrap();
let _: EncodableRng = json::decode(&serialized).unwrap();
}
}

0 comments on commit 75512db

Please sign in to comment.