
Commit bf72094

moved back and cleaned up bernoulli and poisson cells
1 parent 9afaadf commit bf72094

2 files changed
+51 -123 lines changed

ngclearn/components/input_encoders/bernoulliCell.py
+5 -72 lines changed

@@ -24,49 +24,10 @@ def _update_times(t, s, tols):
     _tols = (1. - s) * tols + (s * t)
     return _tols
 
-@jit
-def _sample_bernoulli(dkey, data):
-    """
-    Samples a Bernoulli spike train on-the-fly
-
-    Args:
-        dkey: JAX key to drive stochasticity/noise
-
-        data: sensory data (vector/matrix)
-
-    Returns:
-        binary spikes
-    """
-    s_t = random.bernoulli(dkey, p=data).astype(jnp.float32)
-    return s_t
-
-@partial(jit, static_argnums=[3])
-def _sample_constrained_bernoulli(dkey, data, dt, fmax=63.75):
-    """
-    Samples a Bernoulli spike train on-the-fly that is constrained to emit
-    at a particular rate over a time window.
-
-    Args:
-        dkey: JAX key to drive stochasticity/noise
-
-        data: sensory data (vector/matrix)
-
-        dt: integration time constant
-
-        fmax: maximum frequency (Hz)
-
-    Returns:
-        binary spikes
-    """
-    pspike = data * (dt/1000.) * fmax
-    eps = random.uniform(dkey, data.shape, minval=0., maxval=1., dtype=jnp.float32)
-    s_t = (eps < pspike).astype(jnp.float32)
-    return s_t
-
 class BernoulliCell(JaxComponent):
     """
-    A Bernoulli cell that produces variations of Bernoulli-distributed spikes
-    on-the-fly (including constrained-rate trains).
+    A Bernoulli cell that produces spikes by sampling a Bernoulli distribution
+    on-the-fly (to produce data-scaled Bernoulli spike trains).
 
     | --- Cell Input Compartments: ---
     | inputs - input (takes in external signals)
@@ -80,17 +41,11 @@ class BernoulliCell(JaxComponent):
         name: the string name of this cell
 
         n_units: number of cellular entities (neural population size)
-
-        target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
     """
 
-    @deprecate_args(max_freq="target_freq")
-    def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
+    def __init__(self, name, n_units, batch_size=1, **kwargs):
         super().__init__(name, **kwargs)
 
-        ## Constrained Bernoulli meta-parameters
-        self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)
-
         ## Layer Size Setup
         self.batch_size = batch_size
         self.n_units = n_units
@@ -101,32 +56,10 @@ def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
         self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
         self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
 
-    def validate(self, dt=None, **validation_kwargs):
-        valid = super().validate(**validation_kwargs)
-        if dt is None:
-            warn(f"{self.name} requires a validation kwarg of `dt`")
-            return False
-        ## check for unstable combinations of dt and target-frequency meta-params
-        events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
-        if events_per_timestep > 1.:
-            valid = False
-            warn(
-                f"{self.name} will be unable to make as many temporal events as "
-                f"requested! ({events_per_timestep} events/timestep) Unstable "
-                f"combination of dt = {dt} and target_freq = {self.target_freq} "
-                f"being used!"
-            )
-        return valid
-
     @staticmethod
-    def _advance_state(t, dt, target_freq, key, inputs, tols):
+    def _advance_state(t, key, inputs, tols):
         key, *subkeys = random.split(key, 2)
-        if target_freq > 0.:
-            outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate
-                subkeys[0], data=inputs, dt=dt, fmax=target_freq
-            )
-        else:
-            outputs = _sample_bernoulli(subkeys[0], data=inputs)
+        outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32)
         tols = _update_times(t, outputs, tols)
         return outputs, tols, key
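
For reference, the simplified BernoulliCell now treats each (normalized) input value as a per-step firing probability. Below is a minimal standalone sketch of that sampling step; the helper name drive_bernoulli_spikes and the toy input are illustrative only and are not part of ngclearn.

from jax import numpy as jnp, random, jit

@jit
def drive_bernoulli_spikes(key, inputs, t, tols):
    ## split the key, mirroring what _advance_state does
    key, subkey = random.split(key, 2)
    ## each unit spikes with probability equal to its input value in [0, 1]
    spikes = random.bernoulli(subkey, p=inputs).astype(jnp.float32)
    ## update time-of-last-spike the same way _update_times does
    tols = (1. - spikes) * tols + spikes * t
    return spikes, tols, key

key = random.PRNGKey(0)
x = jnp.array([[0.1, 0.5, 0.9]])  ## e.g., pixel intensities scaled to [0, 1]
spikes, tols, key = drive_bernoulli_spikes(key, x, t=1., tols=jnp.zeros_like(x))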

ngclearn/components/input_encoders/poissonCell.py
+46 -51 lines changed

@@ -1,15 +1,33 @@
 from ngclearn import resolver, Component, Compartment
 from ngclearn.components.jaxComponent import JaxComponent
+from jax import numpy as jnp, random, jit
 from ngclearn.utils import tensorstats
-from jax import numpy as jnp, random, jit, scipy
 from functools import partial
 from ngcsimlib.deprecators import deprecate_args
 from ngcsimlib.logger import info, warn
 
+@jit
+def _update_times(t, s, tols):
+    """
+    Updates time-of-last-spike (tols) variable.
+
+    Args:
+        t: current time (a scalar/int value)
+
+        s: binary spike vector
+
+        tols: current time-of-last-spike variable
+
+    Returns:
+        updated tols variable
+    """
+    _tols = (1. - s) * tols + (s * t)
+    return _tols
+
 class PoissonCell(JaxComponent):
     """
-    A Poisson cell that produces approximately Poisson-distributed spikes
-    on-the-fly.
+    A Poisson cell that samples a homogeneous Poisson process on-the-fly to
+    produce a spike train.
 
     | --- Cell Input Compartments: ---
     | inputs - input (takes in external signals)
@@ -24,45 +42,33 @@ class PoissonCell(JaxComponent):
 
         n_units: number of cellular entities (neural population size)
 
-        max_freq: maximum frequency (in Hertz) of this Poisson spike train (
-            must be > 0.)
+        target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
     """
 
-    # Define Functions
     @deprecate_args(max_freq="target_freq")
-    def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
-                 **kwargs):
+    def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
         super().__init__(name, **kwargs)
 
-        ## Poisson meta-parameters
+        ## Constrained Bernoulli meta-parameters
        self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)
 
         ## Layer Size Setup
         self.batch_size = batch_size
         self.n_units = n_units
 
-        _key, subkey = random.split(self.key.value, 2)
-        self.key.set(_key)
-        ## Compartment setup
+        # Compartments (state of the cell, parameters, will be updated through stateless calls)
         restVals = jnp.zeros((self.batch_size, self.n_units))
-        self.inputs = Compartment(restVals,
-                                  display_name="Input Stimulus") # input
-        # compartment
-        self.outputs = Compartment(restVals,
-                                   display_name="Spikes") # output compartment
-        self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
-                                units="ms") # time of last spike
-        self.targets = Compartment(
-            random.uniform(subkey, (self.batch_size, self.n_units), minval=0.,
-                           maxval=1.))
+        self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
+        self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
+        self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
 
     def validate(self, dt=None, **validation_kwargs):
         valid = super().validate(**validation_kwargs)
         if dt is None:
             warn(f"{self.name} requires a validation kwarg of `dt`")
             return False
         ## check for unstable combinations of dt and target-frequency meta-params
-        events_per_timestep = (dt / 1000.) * self.target_freq ## compute scaled probability
+        events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
         if events_per_timestep > 1.:
             valid = False
             warn(
@@ -74,54 +80,43 @@ def validate(self, dt=None, **validation_kwargs):
         return valid
 
     @staticmethod
-    def _advance_state(t, dt, target_freq, key, inputs, targets, tols):
-        ms_per_second = 1000 # ms/s
-        events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
-        ms_per_event = 1 / events_per_ms # ms/e
-        time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
-
-        cdf = scipy.special.gammaincc((t + dt) - tols,
-                                      time_step_per_event/inputs)
-        outputs = (targets < cdf).astype(jnp.float32)
-
-        key, subkey = random.split(key, 2)
-        targets = (targets * (1 - outputs) + random.uniform(subkey,
-                                                            targets.shape) *
-                   outputs)
-
-        tols = tols * (1. - outputs) + t * outputs
-        return outputs, tols, key, targets
+    def _advance_state(t, dt, target_freq, key, inputs, tols):
+        key, *subkeys = random.split(key, 2)
+        pspike = inputs * (dt / 1000.) * target_freq
+        eps = random.uniform(subkeys[0], inputs.shape, minval=0., maxval=1.,
+                             dtype=jnp.float32)
+        outputs = (eps < pspike).astype(jnp.float32)
+        tols = _update_times(t, outputs, tols)
+        return outputs, tols, key
 
     @resolver(_advance_state)
-    def advance_state(self, outputs, tols, key, targets):
+    def advance_state(self, outputs, tols, key):
         self.outputs.set(outputs)
         self.tols.set(tols)
         self.key.set(key)
-        self.targets.set(targets)
 
     @staticmethod
-    def _reset(batch_size, n_units, key):
+    def _reset(batch_size, n_units):
         restVals = jnp.zeros((batch_size, n_units))
-        key, subkey = random.split(key, 2)
-        targets = random.uniform(subkey, (batch_size, n_units))
-        return restVals, restVals, restVals, targets, key
+        return restVals, restVals, restVals
 
     @resolver(_reset)
-    def reset(self, inputs, outputs, tols, targets, key):
+    def reset(self, inputs, outputs, tols):
         self.inputs.set(inputs)
-        self.outputs.set(outputs)
+        self.outputs.set(outputs) #None
         self.tols.set(tols)
-        self.key.set(key)
-        self.targets.set(targets)
 
     def save(self, directory, **kwargs):
+        target_freq = (self.target_freq if isinstance(self.target_freq, float)
+                       else jnp.ones([[self.target_freq]]))
         file_name = directory + "/" + self.name + ".npz"
-        jnp.savez(file_name, key=self.key.value)
+        jnp.savez(file_name, key=self.key.value, target_freq=target_freq)
 
     def load(self, directory, **kwargs):
         file_name = directory + "/" + self.name + ".npz"
         data = jnp.load(file_name)
         self.key.set(data['key'])
+        self.target_freq = data['target_freq']
 
     @classmethod
     def help(cls): ## component help function
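
For reference, the rewritten PoissonCell now emits rate-constrained Bernoulli spikes: the per-step firing probability is inputs * (dt/1000.) * target_freq, so with dt = 1 ms and target_freq = 63.75 Hz a unit with input 1.0 fires with probability about 0.064 per step, i.e. roughly 63.75 expected spikes per second. Below is a minimal standalone sketch of that step; the function name rate_coded_spikes and the toy values are illustrative and not part of the ngclearn API.

from jax import numpy as jnp, random

def rate_coded_spikes(key, inputs, dt, target_freq):
    ## per-step spike probability = intensity * (dt in seconds) * max rate in Hz
    key, subkey = random.split(key, 2)
    pspike = inputs * (dt / 1000.) * target_freq
    eps = random.uniform(subkey, inputs.shape, minval=0., maxval=1.)
    return (eps < pspike).astype(jnp.float32), key

key = random.PRNGKey(42)
x = jnp.array([[0.25, 1.0]])  ## normalized intensities
spikes, key = rate_coded_spikes(key, x, dt=1., target_freq=63.75)
## over many 1 ms steps these units average roughly 15.9 Hz and 63.75 Hz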
