Skip to content

Commit e240644

Browse files
authored
Monitor plot (#66)
* Update base_monitor.py * added plotting viewed compartments
1 parent 1ddd86d commit e240644

File tree

1 file changed

+69
-16
lines changed

1 file changed

+69
-16
lines changed

ngclearn/components/base_monitor.py

+69-16
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from ngclearn import Component, Compartment
44
from ngclearn import numpy as np
5-
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, get_current_path
5+
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
6+
get_current_path
67
from ngcsimlib.logger import warn, critical
8+
import matplotlib.pyplot as plt
79

810

911
class Base_Monitor(Component):
@@ -21,7 +23,8 @@ class Base_Monitor(Component):
2123
Using custom window length:
2224
myMonitor.watch(myComponent.myCompartment, customWindowLength)
2325
24-
To get values out of the monitor either path to the stored value directly, or pass in a compartment directly. All
26+
To get values out of the monitor either path to the stored value
27+
directly, or pass in a compartment directly. All
2528
paths are the same as their local path variable.
2629
2730
Using a compartment:
@@ -30,7 +33,8 @@ class Base_Monitor(Component):
3033
Using a path:
3134
myMonitor.get_store(myComponent.myCompartment.path).value
3235
33-
There can only be one monitor in existence at a time due to the way it interacts with resolvers and the compilers
36+
There can only be one monitor in existence at a time due to the way it
37+
interacts with resolvers and the compilers
3438
for ngclearn.
3539
3640
Args:
@@ -53,10 +57,10 @@ def build_advance(compartments):
5357
5458
"""
5559
critical(
56-
"build_advance() is not defined on this monitor, use either the monitor found in ngclearn.components or "
60+
"build_advance() is not defined on this monitor, use either the "
61+
"monitor found in ngclearn.components or "
5762
"ngclearn.components.lava (If using lava)")
5863

59-
6064
@staticmethod
6165
def build_reset(compartments):
6266
"""
@@ -66,6 +70,7 @@ def build_reset(compartments):
6670
6771
Returns: The method to reset the stored values.
6872
"""
73+
6974
@staticmethod
7075
def _reset(**kwargs):
7176
return_vals = []
@@ -95,7 +100,8 @@ def __lshift__(self, other):
95100

96101
def watch(self, compartment, window_length):
97102
"""
98-
Sets the monitor to watch a specific compartment, for a specified window length.
103+
Sets the monitor to watch a specific compartment, for a specified
104+
window length.
99105
100106
Args:
101107
compartment: the compartment object to monitor
@@ -150,7 +156,7 @@ def halt_all(self):
150156
"""
151157
for compartment in self._sources:
152158
self.halt(compartment)
153-
159+
154160
def _update_resolver(self):
155161
output_compartments = []
156162
compartments = []
@@ -162,13 +168,18 @@ def _update_resolver(self):
162168
parameters = []
163169

164170
add_component_resolver(self.__class__.__name__, "advance_state",
165-
(self.build_advance(compartments), output_compartments))
171+
(self.build_advance(compartments),
172+
output_compartments))
166173
add_resolver_meta(self.__class__.__name__, "advance_state",
167-
(args, parameters, compartments + [o for o in output_compartments], False))
174+
(args, parameters,
175+
compartments + [o for o in output_compartments],
176+
False))
168177

169-
add_component_resolver(self.__class__.__name__, "reset", (self.build_reset(compartments), output_compartments))
178+
add_component_resolver(self.__class__.__name__, "reset", (
179+
self.build_reset(compartments), output_compartments))
170180
add_resolver_meta(self.__class__.__name__, "reset",
171-
(args, parameters, [o for o in output_compartments], False))
181+
(args, parameters, [o for o in output_compartments],
182+
False))
172183

173184
def _add_path(self, path):
174185
_path = path.split("/")[1:]
@@ -210,7 +221,8 @@ def save(self, directory, **kwargs):
210221
for key in self.compartments:
211222
n = key.split("/")[-1]
212223
_dict["sources"][key] = self.__dict__[n].value.shape
213-
_dict["stores"][key + "*store"] = self.__dict__[n + "*store"].value.shape
224+
_dict["stores"][key + "*store"] = self.__dict__[
225+
n + "*store"].value.shape
214226

215227
with open(file_name, "w") as f:
216228
json.dump(_dict, f)
@@ -221,9 +233,9 @@ def load(self, directory, **kwargs):
221233
vals = json.load(f)
222234

223235
for comp_path, shape in vals["stores"].items():
224-
225236
compartment_path = comp_path.split("/")[-1]
226-
new_path = get_current_path() + "/" + "/".join(compartment_path.split("*")[-3:-1])
237+
new_path = get_current_path() + "/" + "/".join(
238+
compartment_path.split("*")[-3:-1])
227239

228240
cs, end = self._add_path(new_path)
229241

@@ -233,8 +245,6 @@ def load(self, directory, **kwargs):
233245
cs[end] = new_comp
234246
setattr(self, compartment_path, new_comp)
235247

236-
237-
238248
for comp_path, shape in vals['sources'].items():
239249
compartment_path = comp_path.split("/")[-1]
240250
new_comp = Compartment(np.zeros(shape))
@@ -244,3 +254,46 @@ def load(self, directory, **kwargs):
244254
self.compartments.append(new_comp.path)
245255

246256
self._update_resolver()
257+
258+
def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None):
259+
vals = self.view(compartment)
260+
261+
if n is None:
262+
n = vals.shape[2]
263+
if title is None:
264+
title = compartment.name.split("/")[0] + " " + compartment.display_name
265+
266+
if ylabel is None:
267+
_ylabel = compartment.units
268+
elif ylabel:
269+
_ylabel = ylabel
270+
else:
271+
_ylabel = None
272+
273+
if xlabel is None:
274+
_xlabel = "Time Steps"
275+
elif xlabel:
276+
_xlabel = xlabel
277+
else:
278+
_xlabel = None
279+
280+
if ax is None:
281+
_ax = plt
282+
_ax.title(title)
283+
if _ylabel:
284+
_ax.ylabel(_ylabel)
285+
if _xlabel:
286+
_ax.xlabel(_xlabel)
287+
else:
288+
_ax = ax
289+
_ax.set_title(title)
290+
if _ylabel:
291+
_ax.set_ylabel(_ylabel)
292+
if _xlabel:
293+
_ax.set_xlabel(_xlabel)
294+
295+
if plot_func is None:
296+
for k in range(n):
297+
_ax.plot(vals[:, 0, k])
298+
else:
299+
plot_func(vals, ax=_ax)

0 commit comments

Comments
 (0)