
Commit 77f347f

integrated phasor-cell, minor cleanup of latency
1 parent bf06510 commit 77f347f

File tree

4 files changed: +183 −2 lines changed

ngclearn/components/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@
 from .input_encoders.bernoulliCell import BernoulliCell
 from .input_encoders.poissonCell import PoissonCell
 from .input_encoders.latencyCell import LatencyCell
+from .input_encoders.phasorCell import PhasorCell
 ## point to synapse component types
 from .synapses.denseSynapse import DenseSynapse
 from .synapses.staticSynapse import StaticSynapse
ngclearn/components/input_encoders/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from .bernoulliCell import BernoulliCell
 from .poissonCell import PoissonCell
 from .latencyCell import LatencyCell
+from .phasorCell import PhasorCell
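With both package exports in place, the new cell should be importable directly from the components package (a minimal check, assuming ngclearn and its ngcsimlib dependency are installed):

    from ngclearn.components import PhasorCell
    from ngclearn.components.input_encoders import PhasorCell  # equivalent, via the subpackage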

ngclearn/components/input_encoders/latencyCell.py

Lines changed: 2 additions & 2 deletions

@@ -48,7 +48,7 @@ def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1.,
         projected spike times
     """
     _tau = tau
-    if normalize == True:
+    if normalize:
         _tau = num_steps - 1. - first_spk_t ## linear normalization
     #torch.clamp_max((-tau * (data - 1)), -tau * (threshold - 1))
     stimes = -_tau * (data - 1.) ## calc raw latency code values
@@ -85,7 +85,7 @@ def _calc_spike_times_nonlinear(data, tau, thr, first_spk_t, eps=1e-7,
     stimes = jnp.log(_data / (_data - thr)) * tau ## calc spike times
     stimes = stimes + first_spk_t

-    if normalize == True:
+    if normalize:
         term1 = (stimes - first_spk_t)
         term2 = (num_steps - first_spk_t - 1.)
         term3 = jnp.max(stimes - first_spk_t)
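As a quick illustration of the linear branch touched above (values are hypothetical, not from the commit): with normalize enabled, _tau becomes num_steps - 1. - first_spk_t, so strong inputs map to early spike times and weak inputs to late ones.

    from jax import numpy as jnp

    data = jnp.asarray([1.0, 0.5, 0.0])   # normalized input intensities
    num_steps, first_spk_t = 100., 0.
    _tau = num_steps - 1. - first_spk_t   # linear normalization branch
    stimes = -_tau * (data - 1.)          # raw latency code values
    print(stimes)                         # [ 0.  49.5 99. ]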
ngclearn/components/input_encoders/phasorCell.py

Lines changed: 179 additions & 0 deletions

@@ -0,0 +1,179 @@
from ngclearn import resolver, Compartment
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils import tensorstats
from jax import numpy as jnp, random
from ngcsimlib.logger import warn

class PhasorCell(JaxComponent):
    """
    A phasor cell that emits a pulse at a regular interval.

    | --- Cell Input Compartments: ---
    | inputs - input (takes in external signals)
    | --- Cell State Compartments: ---
    | key - JAX PRNG key
    | --- Cell Output Compartments: ---
    | outputs - output
    | tols - time-of-last-spike

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        target_freq: maximum frequency (in Hertz) of this spike train
            (must be > 0.)
    """

    # Define Functions
    def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
                 **kwargs):
        super().__init__(name, **kwargs)

        ## Phasor meta-parameters
        self.target_freq = target_freq  ## maximum frequency (in Hertz/Hz)

        ## Layer Size Setup
        self.batch_size = batch_size
        self.n_units = n_units
        _key, subkey = random.split(self.key.value, 2)
        self.key.set(_key)
        ## Compartment setup
        restVals = jnp.zeros((self.batch_size, self.n_units))
        self.inputs = Compartment(restVals,
                                  display_name="Input Stimulus")  # input compartment
        self.outputs = Compartment(restVals,
                                   display_name="Spikes")  # output compartment
        self.tols = Compartment(initial_value=restVals,
                                display_name="Time-of-Last-Spike",
                                units="ms")  # time of last spike
        self.angles = Compartment(restVals, display_name="Angles", units="deg")
        # self.base_scale = random.uniform(subkey, self.angles.value.shape,
        #                                  minval=0.75, maxval=1.25)
        # self.base_scale = ((random.normal(subkey, self.angles.value.shape) * 0.15) + 1)
        # alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1)
        # beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq

        ## per-unit scaling of the phase-advance rate, centered near 1
        self.base_scale = random.poisson(subkey, lam=target_freq,
                                         shape=self.angles.value.shape) / target_freq

    def validate(self, dt=None, **validation_kwargs):
        valid = super().validate(**validation_kwargs)
        if dt is None:
            warn(f"{self.name} requires a validation kwarg of `dt`")
            return False
        ## check for unstable combinations of dt and target-frequency meta-params
        events_per_timestep = (dt / 1000.) * self.target_freq  ## compute scaled probability
        if events_per_timestep > 1.:
            valid = False
            warn(
                f"{self.name} will be unable to make as many temporal events as "
                f"requested! ({events_per_timestep} events/timestep) Unstable "
                f"combination of dt = {dt} and target_freq = {self.target_freq} "
                f"being used!"
            )
        return valid

    @staticmethod
    def _advance_state(t, dt, target_freq, key,
                       inputs, angles, tols, base_scale):
        ms_per_second = 1000  # ms/s
        events_per_ms = target_freq / ms_per_second  # e/s * s/ms -> e/ms
        ms_per_event = 1 / events_per_ms  # ms/e
        time_step_per_event = ms_per_event / dt  # ms/e * ts/ms -> ts/e
        angle_per_event = 2 * jnp.pi  # rad/e
        angle_per_timestep = angle_per_event / time_step_per_event  # rad/e * e/ts -> rad/ts
        key, subkey = random.split(key, 2)
        # scatter = random.uniform(subkey, angles.shape, minval=0.5,
        #                          maxval=1.5) * base_scale

        scatter = ((random.normal(subkey, angles.shape) * 0.2) + 1) * base_scale
        scattered_update = angle_per_timestep * scatter
        scaled_scattered_update = scattered_update * inputs

        updated_angles = angles + scaled_scattered_update
        outputs = jnp.where(updated_angles > angle_per_event, 1., 0.)
        updated_angles = jnp.where(updated_angles > angle_per_event,
                                   updated_angles - angle_per_event,
                                   updated_angles)
        tols = tols * (1. - outputs) + t * outputs

        return outputs, tols, key, updated_angles

    @resolver(_advance_state)
    def advance_state(self, outputs, tols, key, angles):
        self.outputs.set(outputs)
        self.tols.set(tols)
        self.key.set(key)
        self.angles.set(angles)

    @staticmethod
    def _reset(batch_size, n_units, key, target_freq):
        restVals = jnp.zeros((batch_size, n_units))
        key, subkey = random.split(key, 2)
        return restVals, restVals, restVals, restVals, key

    @resolver(_reset)
    def reset(self, inputs, outputs, tols, angles, key):
        self.inputs.set(inputs)
        self.outputs.set(outputs)
        self.tols.set(tols)
        self.key.set(key)
        self.angles.set(angles)

    def save(self, directory, **kwargs):
        file_name = directory + "/" + self.name + ".npz"
        jnp.savez(file_name, key=self.key.value)

    def load(self, directory, **kwargs):
        file_name = directory + "/" + self.name + ".npz"
        data = jnp.load(file_name)
        self.key.set(data['key'])

    @classmethod
    def help(cls):  ## component help function
        properties = {
            "cell_type": "Phasor - Produces spike pulses at fairly regular "
                         "intervals with small amounts of noise"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values"},
            "states":
                {"key": "JAX PRNG key",
                 "angles": "The current angle of the phasor"},
            "outputs":
                {"tols": "Time-of-last-spike",
                 "outputs": "Binary spike values emitted at time t"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "batch_size": "Batch size dimension of this component",
            "target_freq": "Maximum spike frequency of the train produced",
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "hyperparameters": hyperparams}
        return info

    def __repr__(self):
        comps = [varname for varname in dir(self) if
                 Compartment.is_compartment(getattr(self, varname))]
        maxlen = max(len(c) for c in comps) + 5
        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
        for c in comps:
            stats = tensorstats(getattr(self, c).value)
            if stats is not None:
                line = [f"{k}: {v}" for k, v in stats.items()]
                line = ", ".join(line)
            else:
                line = "None"
            lines += f"  {f'({c})'.ljust(maxlen)}{line}\n"
        return lines
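For orientation, below is a minimal standalone sketch of the encoding rule the new cell implements: a per-neuron phase angle advances each step in proportion to the input, and a spike is emitted whenever the phase wraps 2π. This is not the ngclearn component API; the function name simulate_phasor and the constants are illustrative only, and the per-neuron Poisson-drawn base_scale factor from the commit is omitted for brevity.

    from jax import numpy as jnp, random

    def simulate_phasor(key, inputs, target_freq=63.75, dt=1., T=200):
        """Accumulate a per-neuron phase; emit a spike when it wraps 2*pi."""
        ms_per_second = 1000.
        time_steps_per_event = (ms_per_second / target_freq) / dt   # steps between events
        angle_per_timestep = 2 * jnp.pi / time_steps_per_event      # radians advanced per step
        angles = jnp.zeros_like(inputs)
        spikes = []
        for _ in range(T):
            key, subkey = random.split(key)
            jitter = (random.normal(subkey, inputs.shape) * 0.2) + 1.  # small multiplicative noise
            angles = angles + angle_per_timestep * jitter * inputs     # input gates the phase advance
            s = (angles > 2 * jnp.pi).astype(jnp.float32)
            angles = jnp.where(s > 0., angles - 2 * jnp.pi, angles)    # wrap after a spike
            spikes.append(s)
        return jnp.stack(spikes)

    key = random.PRNGKey(0)
    x = jnp.asarray([[1., 0.5, 0.25]])        # normalized input intensities in [0, 1]
    raster = simulate_phasor(key, x, dt=1., T=200)
    print(raster.sum(axis=0))                 # higher input -> more spikes over the window

For reference, the validate() guard in the new component flags cases where (dt / 1000) * target_freq exceeds 1 event per timestep: with dt = 1 ms and the default target_freq = 63.75 Hz, this is roughly 0.064 events per step, well within the stable regime.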

0 commit comments