2
2
3
3
from ngclearn import Component , Compartment
4
4
from ngclearn import numpy as np
5
- from ngcsimlib .utils import add_component_resolver , add_resolver_meta , get_current_path
5
+ from ngcsimlib .utils import add_component_resolver , add_resolver_meta , \
6
+ get_current_path
6
7
from ngcsimlib .logger import warn , critical
8
+ import matplotlib .pyplot as plt
7
9
8
10
9
11
class Base_Monitor (Component ):
@@ -21,7 +23,8 @@ class Base_Monitor(Component):
21
23
Using custom window length:
22
24
myMonitor.watch(myComponent.myCompartment, customWindowLength)
23
25
24
- To get values out of the monitor either path to the stored value directly, or pass in a compartment directly. All
26
+ To get values out of the monitor either path to the stored value
27
+ directly, or pass in a compartment directly. All
25
28
paths are the same as their local path variable.
26
29
27
30
Using a compartment:
@@ -30,7 +33,8 @@ class Base_Monitor(Component):
30
33
Using a path:
31
34
myMonitor.get_store(myComponent.myCompartment.path).value
32
35
33
- There can only be one monitor in existence at a time due to the way it interacts with resolvers and the compilers
36
+ There can only be one monitor in existence at a time due to the way it
37
+ interacts with resolvers and the compilers
34
38
for ngclearn.
35
39
36
40
Args:
@@ -53,10 +57,10 @@ def build_advance(compartments):
53
57
54
58
"""
55
59
critical (
56
- "build_advance() is not defined on this monitor, use either the monitor found in ngclearn.components or "
60
+ "build_advance() is not defined on this monitor, use either the "
61
+ "monitor found in ngclearn.components or "
57
62
"ngclearn.components.lava (If using lava)" )
58
63
59
-
60
64
@staticmethod
61
65
def build_reset (compartments ):
62
66
"""
@@ -66,6 +70,7 @@ def build_reset(compartments):
66
70
67
71
Returns: The method to reset the stored values.
68
72
"""
73
+
69
74
@staticmethod
70
75
def _reset (** kwargs ):
71
76
return_vals = []
@@ -95,7 +100,8 @@ def __lshift__(self, other):
95
100
96
101
def watch (self , compartment , window_length ):
97
102
"""
98
- Sets the monitor to watch a specific compartment, for a specified window length.
103
+ Sets the monitor to watch a specific compartment, for a specified
104
+ window length.
99
105
100
106
Args:
101
107
compartment: the compartment object to monitor
@@ -150,7 +156,7 @@ def halt_all(self):
150
156
"""
151
157
for compartment in self ._sources :
152
158
self .halt (compartment )
153
-
159
+
154
160
def _update_resolver (self ):
155
161
output_compartments = []
156
162
compartments = []
@@ -162,13 +168,18 @@ def _update_resolver(self):
162
168
parameters = []
163
169
164
170
add_component_resolver (self .__class__ .__name__ , "advance_state" ,
165
- (self .build_advance (compartments ), output_compartments ))
171
+ (self .build_advance (compartments ),
172
+ output_compartments ))
166
173
add_resolver_meta (self .__class__ .__name__ , "advance_state" ,
167
- (args , parameters , compartments + [o for o in output_compartments ], False ))
174
+ (args , parameters ,
175
+ compartments + [o for o in output_compartments ],
176
+ False ))
168
177
169
- add_component_resolver (self .__class__ .__name__ , "reset" , (self .build_reset (compartments ), output_compartments ))
178
+ add_component_resolver (self .__class__ .__name__ , "reset" , (
179
+ self .build_reset (compartments ), output_compartments ))
170
180
add_resolver_meta (self .__class__ .__name__ , "reset" ,
171
- (args , parameters , [o for o in output_compartments ], False ))
181
+ (args , parameters , [o for o in output_compartments ],
182
+ False ))
172
183
173
184
def _add_path (self , path ):
174
185
_path = path .split ("/" )[1 :]
@@ -210,7 +221,8 @@ def save(self, directory, **kwargs):
210
221
for key in self .compartments :
211
222
n = key .split ("/" )[- 1 ]
212
223
_dict ["sources" ][key ] = self .__dict__ [n ].value .shape
213
- _dict ["stores" ][key + "*store" ] = self .__dict__ [n + "*store" ].value .shape
224
+ _dict ["stores" ][key + "*store" ] = self .__dict__ [
225
+ n + "*store" ].value .shape
214
226
215
227
with open (file_name , "w" ) as f :
216
228
json .dump (_dict , f )
@@ -221,9 +233,9 @@ def load(self, directory, **kwargs):
221
233
vals = json .load (f )
222
234
223
235
for comp_path , shape in vals ["stores" ].items ():
224
-
225
236
compartment_path = comp_path .split ("/" )[- 1 ]
226
- new_path = get_current_path () + "/" + "/" .join (compartment_path .split ("*" )[- 3 :- 1 ])
237
+ new_path = get_current_path () + "/" + "/" .join (
238
+ compartment_path .split ("*" )[- 3 :- 1 ])
227
239
228
240
cs , end = self ._add_path (new_path )
229
241
@@ -233,8 +245,6 @@ def load(self, directory, **kwargs):
233
245
cs [end ] = new_comp
234
246
setattr (self , compartment_path , new_comp )
235
247
236
-
237
-
238
248
for comp_path , shape in vals ['sources' ].items ():
239
249
compartment_path = comp_path .split ("/" )[- 1 ]
240
250
new_comp = Compartment (np .zeros (shape ))
@@ -244,3 +254,46 @@ def load(self, directory, **kwargs):
244
254
self .compartments .append (new_comp .path )
245
255
246
256
self ._update_resolver ()
257
+
258
+ def make_plot (self , compartment , ax = None , ylabel = None , xlabel = None , title = None , n = None , plot_func = None ):
259
+ vals = self .view (compartment )
260
+
261
+ if n is None :
262
+ n = vals .shape [2 ]
263
+ if title is None :
264
+ title = compartment .name .split ("/" )[0 ] + " " + compartment .display_name
265
+
266
+ if ylabel is None :
267
+ _ylabel = compartment .units
268
+ elif ylabel :
269
+ _ylabel = ylabel
270
+ else :
271
+ _ylabel = None
272
+
273
+ if xlabel is None :
274
+ _xlabel = "Time Steps"
275
+ elif xlabel :
276
+ _xlabel = xlabel
277
+ else :
278
+ _xlabel = None
279
+
280
+ if ax is None :
281
+ _ax = plt
282
+ _ax .title (title )
283
+ if _ylabel :
284
+ _ax .ylabel (_ylabel )
285
+ if _xlabel :
286
+ _ax .xlabel (_xlabel )
287
+ else :
288
+ _ax = ax
289
+ _ax .set_title (title )
290
+ if _ylabel :
291
+ _ax .set_ylabel (_ylabel )
292
+ if _xlabel :
293
+ _ax .set_xlabel (_xlabel )
294
+
295
+ if plot_func is None :
296
+ for k in range (n ):
297
+ _ax .plot (vals [:, 0 , k ])
298
+ else :
299
+ plot_func (vals , ax = _ax )
0 commit comments