Skip to content

Commit 05ea912

Browse files
committed
modded bernoulli-cell to include max-frequency constraint
1 parent 42167cb commit 05ea912

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

ngclearn/components/input_encoders/bernoulliCell.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ngclearn.components.jaxComponent import JaxComponent
33
from jax import numpy as jnp, random, jit
44
from ngclearn.utils import tensorstats
5+
from functools import partial
56

67
@jit
78
def _update_times(t, s, tols):
@@ -37,9 +38,33 @@ def _sample_bernoulli(dkey, data):
3738
s_t = random.bernoulli(dkey, p=data).astype(jnp.float32)
3839
return s_t
3940

41+
@partial(jit, static_argnums=[3])
42+
def _sample_constrained_bernoulli(dkey, data, dt, fmax=63.75):
43+
"""
44+
Samples a Bernoulli spike train on-the-fly that is constrained to emit
45+
at a particular rate over a time window.
46+
47+
Args:
48+
dkey: JAX key to drive stochasticity/noise
49+
50+
data: sensory data (vector/matrix)
51+
52+
dt: integration time constant
53+
54+
fmax: maximum frequency (Hz)
55+
56+
Returns:
57+
binary spikes
58+
"""
59+
pspike = data * (dt/1000.) * fmax
60+
eps = random.uniform(dkey, data.shape, minval=0., maxval=1., dtype=jnp.float32)
61+
s_t = (eps < pspike).astype(jnp.float32)
62+
return s_t
63+
4064
class BernoulliCell(JaxComponent):
4165
"""
42-
A Bernoulli cell that produces Bernoulli-distributed spikes on-the-fly.
66+
A Bernoulli cell that produces variations of Bernoulli-distributed spikes
67+
on-the-fly (including constrained-rate trains).
4368
4469
| --- Cell Input Compartments: ---
4570
| inputs - input (takes in external signals)
@@ -53,12 +78,17 @@ class BernoulliCell(JaxComponent):
5378
name: the string name of this cell
5479
5580
n_units: number of cellular entities (neural population size)
81+
82+
max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
5683
"""
5784

5885
# Define Functions
59-
def __init__(self, name, n_units, batch_size=1, **kwargs):
86+
def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
6087
super().__init__(name, **kwargs)
6188

89+
## Constrained Bernoulli meta-parameters
90+
self.max_freq = max_freq ## maximum frequency (in Hertz/Hz)
91+
6292
## Layer Size Setup
6393
self.batch_size = batch_size
6494
self.n_units = n_units
@@ -70,11 +100,16 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
70100
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
71101

72102
@staticmethod
73-
def _advance_state(t, key, inputs, tols):
103+
def _advance_state(t, dt, max_freq, key, inputs, tols):
74104
key, *subkeys = random.split(key, 2)
75-
outputs = _sample_bernoulli(subkeys[0], data=inputs)
76-
timeOfLastSpike = _update_times(t, outputs, tols)
77-
return outputs, timeOfLastSpike, key
105+
if max_freq > 0.:
106+
outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate
107+
subkeys[0], data=inputs, dt=dt, fmax=max_freq
108+
)
109+
else:
110+
outputs = _sample_bernoulli(subkeys[0], data=inputs)
111+
tols = _update_times(t, outputs, tols)
112+
return outputs, tols, key
78113

79114
@resolver(_advance_state)
80115
def advance_state(self, outputs, tols, key):

0 commit comments

Comments
 (0)