Skip to content

Commit 22ba2d7

Browse files
author
cogmission
committed
Added HTMObjecInput and HTMObjectOutput test
1 parent 6693834 commit 22ba2d7

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package org.numenta.nupic.serialize;
2+
3+
import static org.junit.Assert.assertNotNull;
4+
import static org.junit.Assert.assertTrue;
5+
import static org.junit.Assert.fail;
6+
7+
import java.io.ByteArrayInputStream;
8+
import java.io.ByteArrayOutputStream;
9+
10+
import org.junit.Test;
11+
import org.numenta.nupic.Parameters;
12+
import org.numenta.nupic.Parameters.KEY;
13+
import org.numenta.nupic.algorithms.Anomaly;
14+
import org.numenta.nupic.algorithms.SpatialPooler;
15+
import org.numenta.nupic.algorithms.TemporalMemory;
16+
import org.numenta.nupic.network.Network;
17+
import org.numenta.nupic.network.NetworkTestHarness;
18+
import org.numenta.nupic.network.Persistence;
19+
import org.numenta.nupic.network.PublisherSupplier;
20+
import org.numenta.nupic.network.sensor.ObservableSensor;
21+
import org.numenta.nupic.network.sensor.Sensor;
22+
import org.numenta.nupic.network.sensor.SensorParams;
23+
import org.numenta.nupic.network.sensor.SensorParams.Keys;
24+
import org.numenta.nupic.util.FastRandom;
25+
26+
27+
public class HTMObjectInputOutputTest {
28+
29+
@Test
30+
public void testRoundTrip() {
31+
Network network = getLoadedHotGymNetwork();
32+
SerializerCore serializer = Persistence.get().serializer();
33+
ByteArrayOutputStream baos = new ByteArrayOutputStream();
34+
HTMObjectOutput writer = serializer.getObjectOutput(baos);
35+
try {
36+
writer.writeObject(network, Network.class);
37+
writer.flush();
38+
writer.close();
39+
}catch(Exception e) {
40+
fail();
41+
}
42+
43+
byte[] bytes = baos.toByteArray();
44+
45+
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
46+
try {
47+
HTMObjectInput reader = serializer.getObjectInput(bais);
48+
Network serializedNetwork = (Network)reader.readObject(Network.class);
49+
assertNotNull(serializedNetwork);
50+
assertTrue(serializedNetwork.equals(network));
51+
}catch(Exception e) {
52+
e.printStackTrace();
53+
fail();
54+
}
55+
}
56+
57+
private Network getLoadedHotGymNetwork() {
58+
Parameters p = NetworkTestHarness.getParameters().copy();
59+
p = p.union(NetworkTestHarness.getHotGymTestEncoderParams());
60+
p.setParameterByKey(KEY.RANDOM, new FastRandom(42));
61+
62+
Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
63+
ObservableSensor::create, SensorParams.create(Keys::obs, new Object[] {"name",
64+
PublisherSupplier.builder()
65+
.addHeader("timestamp, consumption")
66+
.addHeader("datetime, float")
67+
.addHeader("B").build() }));
68+
69+
Network network = Network.create("test network", p).add(Network.createRegion("r1")
70+
.add(Network.createLayer("1", p)
71+
.alterParameter(KEY.AUTO_CLASSIFY, true)
72+
.add(Anomaly.create())
73+
.add(new TemporalMemory())
74+
.add(new SpatialPooler())
75+
.add(sensor)));
76+
77+
return network;
78+
}
79+
}

0 commit comments

Comments
 (0)