Skip to content

Commit 05a97f0

Browse files
committed
updated the poissonCell to be a true poisson
1 parent 05ea912 commit 05a97f0

File tree

1 file changed

+61
-64
lines changed

1 file changed

+61
-64
lines changed

Diff for: ngclearn/components/input_encoders/poissonCell.py

+61-64
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,15 @@
11
from ngclearn import resolver, Component, Compartment
22
from ngclearn.components.jaxComponent import JaxComponent
33
from ngclearn.utils import tensorstats
4-
from jax import numpy as jnp, random, jit
4+
from jax import numpy as jnp, random, jit, scipy
55
from functools import partial
6+
from ngcsimlib.deprecators import deprecate_args
67

7-
@jit
8-
def _update_times(t, s, tols):
9-
"""
10-
Updates time-of-last-spike (tols) variable.
11-
12-
Args:
13-
t: current time (a scalar/int value)
14-
15-
s: binary spike vector
16-
17-
tols: current time-of-last-spike variable
18-
19-
Returns:
20-
updated tols variable
21-
"""
22-
_tols = (1. - s) * tols + (s * t)
23-
return _tols
24-
25-
@partial(jit, static_argnums=[3])
26-
def _sample_poisson(dkey, data, dt, fmax=63.75):
27-
"""
28-
Samples a Poisson spike train on-the-fly.
29-
30-
Args:
31-
dkey: JAX key to drive stochasticity/noise
32-
33-
data: sensory data (vector/matrix)
34-
35-
dt: integration time constant
36-
37-
fmax: maximum frequency (Hz)
38-
39-
Returns:
40-
binary spikes
41-
"""
42-
pspike = data * (dt/1000.) * fmax
43-
eps = random.uniform(dkey, data.shape, minval=0., maxval=1., dtype=jnp.float32)
44-
s_t = (eps < pspike).astype(jnp.float32)
45-
return s_t
468

479
class PoissonCell(JaxComponent):
4810
"""
49-
A Poisson cell that produces approximately Poisson-distributed spikes on-the-fly.
11+
A Poisson cell that produces approximately Poisson-distributed spikes
12+
on-the-fly.
5013
5114
| --- Cell Input Compartments: ---
5215
| inputs - input (takes in external signals)
@@ -61,49 +24,78 @@ class PoissonCell(JaxComponent):
6124
6225
n_units: number of cellular entities (neural population size)
6326
64-
max_freq: maximum frequency (in Hertz) of this Poisson spike train (must be > 0.)
27+
max_freq: maximum frequency (in Hertz) of this Poisson spike train (
28+
must be > 0.)
6529
"""
6630

6731
# Define Functions
68-
def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
32+
@deprecate_args(target_freq="max_freq")
33+
def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
34+
**kwargs):
6935
super().__init__(name, **kwargs)
7036

7137
## Poisson meta-parameters
72-
self.max_freq = max_freq ## maximum frequency (in Hertz/Hz)
38+
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)
7339

7440
## Layer Size Setup
7541
self.batch_size = batch_size
7642
self.n_units = n_units
7743

44+
_key, subkey = random.split(self.key.value, 2)
45+
self.key.set(_key)
7846
## Compartment setup
7947
restVals = jnp.zeros((self.batch_size, self.n_units))
80-
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
81-
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
82-
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
48+
self.inputs = Compartment(restVals,
49+
display_name="Input Stimulus") # input
50+
# compartment
51+
self.outputs = Compartment(restVals,
52+
display_name="Spikes") # output compartment
53+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
54+
units="ms") # time of last spike
55+
self.targets = Compartment(
56+
random.uniform(subkey, (self.batch_size, self.n_units), minval=0.,
57+
maxval=1.))
8358

8459
@staticmethod
85-
def _advance_state(t, dt, max_freq, key, inputs, tols):
86-
key, *subkeys = random.split(key, 2)
87-
outputs = _sample_poisson(subkeys[0], data=inputs, dt=dt, fmax=max_freq)
88-
tols = _update_times(t, outputs, tols)
89-
return outputs, tols, key
60+
def _advance_state(t, dt, target_freq, key, inputs, targets, tols):
61+
ms_per_second = 1000 # ms/s
62+
events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
63+
ms_per_event = 1 / events_per_ms # ms/e
64+
time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
65+
66+
cdf = scipy.special.gammaincc((t + dt) - tols,
67+
time_step_per_event/inputs)
68+
outputs = (targets < cdf).astype(jnp.float32)
69+
70+
key, subkey = random.split(key, 2)
71+
targets = (targets * (1 - outputs) + random.uniform(subkey,
72+
targets.shape) *
73+
outputs)
74+
75+
tols = tols * (1. - outputs) + t * outputs
76+
return outputs, tols, key, targets
9077

9178
@resolver(_advance_state)
92-
def advance_state(self, outputs, tols, key):
79+
def advance_state(self, outputs, tols, key, targets):
9380
self.outputs.set(outputs)
9481
self.tols.set(tols)
9582
self.key.set(key)
83+
self.targets.set(targets)
9684

9785
@staticmethod
98-
def _reset(batch_size, n_units):
86+
def _reset(batch_size, n_units, key):
9987
restVals = jnp.zeros((batch_size, n_units))
100-
return restVals, restVals, restVals
88+
key, subkey = random.split(key, 2)
89+
targets = random.uniform(subkey, (batch_size, n_units))
90+
return restVals, restVals, restVals, targets, key
10191

10292
@resolver(_reset)
103-
def reset(self, inputs, outputs, tols):
93+
def reset(self, inputs, outputs, tols, targets, key):
10494
self.inputs.set(inputs)
10595
self.outputs.set(outputs)
10696
self.tols.set(tols)
97+
self.key.set(key)
98+
self.targets.set(targets)
10799

108100
def save(self, directory, **kwargs):
109101
file_name = directory + "/" + self.name + ".npz"
@@ -115,36 +107,39 @@ def load(self, directory, **kwargs):
115107
self.key.set(data['key'])
116108

117109
@classmethod
118-
def help(cls): ## component help function
110+
def help(cls): ## component help function
119111
properties = {
120112
"cell_type": "PoissonCell - samples input to produce spikes, "
121-
"where dimension is a probability proportional to "
122-
"the dimension's magnitude/value/intensity and "
123-
"constrained by a maximum spike frequency (spikes follow "
113+
"where dimension is a probability proportional to "
114+
"the dimension's magnitude/value/intensity and "
115+
"constrained by a maximum spike frequency (spikes "
116+
"follow "
124117
"a Poisson distribution)"
125118
}
126119
compartment_props = {
127120
"inputs":
128121
{"inputs": "Takes in external input signal values"},
129122
"states":
130-
{"key": "JAX PRNG key"},
123+
{"key": "JAX PRNG key",
124+
"targets": "Target cdf for the Poisson distribution"},
131125
"outputs":
132126
{"tols": "Time-of-last-spike",
133127
"outputs": "Binary spike values emitted at time t"},
134128
}
135129
hyperparams = {
136130
"n_units": "Number of neuronal cells to model in this layer",
137131
"batch_size": "Batch size dimension of this component",
138-
"max_freq": "Maximum spike frequency of the train produced",
132+
"target_freq": "Maximum spike frequency of the train produced",
139133
}
140134
info = {cls.__name__: properties,
141135
"compartments": compartment_props,
142-
"dynamics": "~ Poisson(x; max_freq)",
136+
"dynamics": "~ Poisson(x; target_freq)",
143137
"hyperparameters": hyperparams}
144138
return info
145139

146140
def __repr__(self):
147-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
141+
comps = [varname for varname in dir(self) if
142+
Compartment.is_compartment(getattr(self, varname))]
148143
maxlen = max(len(c) for c in comps) + 5
149144
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
150145
for c in comps:
@@ -157,8 +152,10 @@ def __repr__(self):
157152
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
158153
return lines
159154

155+
160156
if __name__ == '__main__':
161157
from ngcsimlib.context import Context
158+
162159
with Context("Bar") as bar:
163160
X = PoissonCell("X", 9)
164161
print(X)

0 commit comments

Comments
 (0)