-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathbase_monitor.py
313 lines (247 loc) · 10.1 KB
/
base_monitor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import json
from ngclearn import Component, Compartment
from ngclearn import numpy as np
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
get_current_path
from ngcsimlib.logger import warn, critical
import matplotlib.pyplot as plt
class Base_Monitor(Component):
    """
    An abstract base for monitors for both ngclearn and its Lava components
    (ngclearn.components.lava). Compartments wired directly into this
    component will have their value tracked during `advance_state` loops
    automatically.

    Note: the monitor currently only works for compiled methods.

    Using the default window length:
        myMonitor << myComponent.myCompartment

    Using a custom window length:
        myMonitor.watch(myComponent.myCompartment, customWindowLength)

    To get values out of the monitor, either pass in a compartment or path
    to the stored value directly. All paths are the same as the watched
    compartment's own path.

    Using a compartment:
        myMonitor.view(myComponent.myCompartment)

    Using a path:
        myMonitor.get_store(myComponent.myCompartment.path).value

    There can only be one monitor in existence at a time due to the way it
    interacts with resolvers and the compilers for ngclearn.

    Args:
        name: The name of the component.

        default_window_length: The default window length used by `<<`.
    """
    auto_resolve = False

    @staticmethod
    def _local_compartment_keys(component):
        """Return (local names, matching "*store" names) of `component`'s watched compartments."""
        compartments = [comp.split("/")[-1] for comp in component.compartments]
        output_compartments = [comp + "*store" for comp in compartments]
        return compartments, output_compartments

    @staticmethod
    def build_advance(compartments):
        """
        Build the pure function that advances the stored values.

        Concrete monitors must override this. The base implementation only
        reports the problem through `critical` (presumably raising — confirm
        against ngcsimlib.logger).

        Args:
            compartments: A list of local compartment names whose values are
                stored.

        Returns:
            The method to advance the stored values.
        """
        critical(
            "build_advance() is not defined on this monitor, use either the "
            "monitor found in ngclearn.components or "
            "ngclearn.components.lava (If using lava)")

    @staticmethod
    def build_reset(component):
        """
        Build the resolver used to reset (zero) the stored values.

        Args:
            component: The monitor instance to resolve; its `compartments`
                list determines which stores are reset.

        Returns:
            A tuple of (pure reset function, output compartments, args,
            params, input compartments).
        """
        compartments, output_compartments = \
            Base_Monitor._local_compartment_keys(component)

        # Kept as a staticmethod object exactly as before, since that is what
        # the resolver/compiler has always received.
        @staticmethod
        def _reset(**kwargs):
            # Each store is replaced by zeros of its current shape.
            return_vals = []
            for comp in compartments:
                current_store = kwargs[comp + "*store"]
                return_vals.append(np.zeros(current_store.shape))
            # NOTE(review): with no watched compartments this indexes an empty
            # list (IndexError); confirm whether the compiler can ever call it.
            return return_vals if len(compartments) > 1 else return_vals[0]

        # pure func, output compartments, args, params, input compartments
        return _reset, output_compartments, [], [], output_compartments

    @staticmethod
    def build_advance_state(component):
        """
        Build the resolver used for `advance_state`.

        Args:
            component: The monitor instance; its `build_advance` supplies the
                pure function and its `compartments` list the wiring.

        Returns:
            A tuple of (pure advance function, output compartments, args,
            params, input compartments). The inputs are the watched values
            followed by their stores.
        """
        compartments, output_compartments = \
            Base_Monitor._local_compartment_keys(component)
        _advance = component.build_advance(compartments)
        return _advance, output_compartments, [], [], compartments + output_compartments

    def __init__(self, name, default_window_length=100, **kwargs):
        super().__init__(name, **kwargs)
        # Nested dict mirroring watched paths; leaves are store compartments.
        self.store = {}
        # Paths of this monitor's own mirror compartments (one per watch).
        self.compartments = []
        # The external compartments being watched.
        self._sources = []
        self.default_window_length = default_window_length

    def __lshift__(self, other):
        """Watch `other` with the default window length (`monitor << compartment`)."""
        if isinstance(other, Compartment):
            self.watch(other, self.default_window_length)
        else:
            warn("Only Compartments can be monitored not", type(other))

    def watch(self, compartment, window_length):
        """
        Set the monitor to watch a specific compartment for a given window
        length.

        A mirror compartment is wired to `compartment`, and a store
        compartment of shape `(window_length, *compartment.value.shape)` is
        registered in `self.store` under the compartment's path.

        Args:
            compartment: The compartment object to monitor.

            window_length: The number of values kept in the store.
        """
        cs, end = self._add_path(compartment.path)

        shape = compartment.value.shape
        new_comp = Compartment(np.zeros(shape))
        new_comp_store = Compartment(np.zeros((window_length, *shape)))

        comp_key = "*".join(compartment.path.split("/"))
        store_comp_key = comp_key + "*store"

        new_comp._setup(self, comp_key)
        new_comp_store._setup(self, store_comp_key)

        new_comp << compartment

        cs[end] = new_comp_store
        setattr(self, comp_key, new_comp)
        setattr(self, store_comp_key, new_comp_store)
        self.compartments.append(new_comp.path)
        self._sources.append(compartment)

    def halt(self, compartment):
        """
        Stop the monitor from watching a specific compartment.

        This does not affect previously compiled methods. The stored values
        are left in `self.store`, so they can still be viewed.

        Args:
            compartment: The compartment object to stop watching.
        """
        if compartment not in self._sources:
            return

        comp_key = "*".join(compartment.path.split("/"))
        store_comp_key = comp_key + "*store"

        self.compartments.remove(getattr(self, comp_key).path)
        self._sources.remove(compartment)

        delattr(self, comp_key)
        delattr(self, store_comp_key)
        self._maybe_update_resolver()

    def halt_all(self):
        """
        Stop the monitor from watching all compartments.
        """
        # Iterate over a snapshot: `halt` removes entries from `_sources`, and
        # mutating the list being iterated would skip every other element.
        for compartment in list(self._sources):
            self.halt(compartment)

    def _maybe_update_resolver(self):
        """
        Refresh resolvers if this monitor defines `_update_resolver`.

        The base class does not define `_update_resolver`.
        `build_advance_state` and `build_reset` read `self.compartments`
        when they are invoked. An unconditional call would therefore raise
        AttributeError, while any subclass that still provides the method
        keeps getting it called.
        """
        update = getattr(self, "_update_resolver", None)
        if update is not None:
            update()

    def _add_path(self, path):
        """
        Ensure nested dicts exist in `self.store` for `path`.

        The first path segment is dropped, as is the last one, which becomes
        the leaf key.

        Returns:
            A tuple of (innermost dict, leaf key).
        """
        _path = path.split("/")[1:]
        end = _path.pop(-1)

        current_store = self.store
        for p in _path:
            if p not in current_store.keys():
                current_store[p] = {}
            current_store = current_store[p]

        return current_store, end

    def view(self, compartment):
        """
        Get the value stored for the specified compartment.

        Args:
            compartment: The compartment to extract the stored value of.

        Returns:
            The stored value, or None if that compartment is not monitored.
        """
        _path = compartment.path.split("/")[1:]
        store = self.get_store(_path)
        return store.value if store is not None else store

    def get_store(self, path):
        """
        Walk the nested store dictionary along `path`.

        Args:
            path: Either a sequence of path segments with the first segment
                already stripped (as `view` passes), or a full path string
                such as `compartment.path`. For a string, the first segment
                is dropped the same way `watch`/`view` do.

        Returns:
            The store compartment (or an intermediate dict) found at `path`,
            or None if nothing is stored there.
        """
        if isinstance(path, str):
            path = path.split("/")[1:]
        current_store = self.store
        for p in path:
            # Reaching a leaf compartment before the path ends means "not found".
            if not isinstance(current_store, dict) or p not in current_store:
                return None
            current_store = current_store[p]
        return current_store

    def save(self, directory, **kwargs):
        """
        Write the shapes of watched values and their stores to
        `<directory>/<name>.json`. Only shapes are saved, not values.
        """
        file_name = directory + "/" + self.name + ".json"
        _dict = {"sources": {}, "stores": {}}
        for key in self.compartments:
            n = key.split("/")[-1]
            _dict["sources"][key] = self.__dict__[n].value.shape
            _dict["stores"][key + "*store"] = self.__dict__[
                n + "*store"].value.shape

        with open(file_name, "w") as f:
            json.dump(_dict, f)

    def load(self, directory, **kwargs):
        """
        Recreate the source and store compartments described by the JSON
        written by `save`. They are zero-filled, since only shapes were saved.
        """
        file_name = directory + "/" + self.name + ".json"
        with open(file_name, "r") as f:
            vals = json.load(f)

            for comp_path, shape in vals["stores"].items():
                compartment_path = comp_path.split("/")[-1]
                # e.g. "ctx*comp*var*store" -> "comp/var" under the current path.
                new_path = get_current_path() + "/" + "/".join(
                    compartment_path.split("*")[-3:-1])

                cs, end = self._add_path(new_path)

                new_comp = Compartment(np.zeros(shape))
                new_comp._setup(self, compartment_path)

                cs[end] = new_comp
                setattr(self, compartment_path, new_comp)

            for comp_path, shape in vals['sources'].items():
                compartment_path = comp_path.split("/")[-1]
                new_comp = Compartment(np.zeros(shape))
                new_comp._setup(self, compartment_path)

                setattr(self, compartment_path, new_comp)
                self.compartments.append(new_comp.path)

            self._maybe_update_resolver()

    def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None):
        """
        Plot the stored values of `compartment`.

        Args:
            compartment: The monitored compartment to plot.

            ax: A matplotlib axes to draw on. Defaults to the `pyplot` state
                machine.

            ylabel: The y-axis label. None uses `compartment.units`; any other
                falsy value suppresses the label.

            xlabel: The x-axis label. None uses "Time Steps"; any other falsy
                value suppresses the label.

            title: The plot title. Defaults to the component name plus the
                compartment's display name.

            n: The number of traces to draw. Defaults to `vals.shape[2]`.

            plot_func: Optional callable `plot_func(vals, ax=...)` that
                replaces the default per-unit line plot.
        """
        vals = self.view(compartment)
        if vals is None:
            critical("make_plot(): compartment is not being monitored:",
                     compartment.path)

        # NOTE(review): the default path assumes vals is (window, 1, units),
        # given the shape[2] / vals[:, 0, k] indexing below — confirm.
        if n is None:
            n = vals.shape[2]
        if title is None:
            title = compartment.name.split("/")[0] + " " + compartment.display_name

        if ylabel is None:
            _ylabel = compartment.units
        elif ylabel:
            _ylabel = ylabel
        else:
            _ylabel = None

        if xlabel is None:
            _xlabel = "Time Steps"
        elif xlabel:
            _xlabel = xlabel
        else:
            _xlabel = None

        if ax is None:
            _ax = plt
            _ax.title(title)
            if _ylabel:
                _ax.ylabel(_ylabel)
            if _xlabel:
                _ax.xlabel(_xlabel)
        else:
            _ax = ax
            _ax.set_title(title)
            if _ylabel:
                _ax.set_ylabel(_ylabel)
            if _xlabel:
                _ax.set_xlabel(_xlabel)

        if plot_func is None:
            for k in range(n):
                _ax.plot(vals[:, 0, k])
        else:
            plot_func(vals, ax=_ax)