2
2
from ngclearn .components .jaxComponent import JaxComponent
3
3
from jax import numpy as jnp , random , jit
4
4
from ngclearn .utils import tensorstats
5
+ from functools import partial
5
6
6
7
@jit
7
8
def _update_times (t , s , tols ):
@@ -37,9 +38,33 @@ def _sample_bernoulli(dkey, data):
37
38
s_t = random .bernoulli (dkey , p = data ).astype (jnp .float32 )
38
39
return s_t
39
40
41
+ @partial (jit , static_argnums = [3 ])
42
+ def _sample_constrained_bernoulli (dkey , data , dt , fmax = 63.75 ):
43
+ """
44
+ Samples a Bernoulli spike train on-the-fly that is constrained to emit
45
+ at a particular rate over a time window.
46
+
47
+ Args:
48
+ dkey: JAX key to drive stochasticity/noise
49
+
50
+ data: sensory data (vector/matrix)
51
+
52
+ dt: integration time constant
53
+
54
+ fmax: maximum frequency (Hz)
55
+
56
+ Returns:
57
+ binary spikes
58
+ """
59
+ pspike = data * (dt / 1000. ) * fmax
60
+ eps = random .uniform (dkey , data .shape , minval = 0. , maxval = 1. , dtype = jnp .float32 )
61
+ s_t = (eps < pspike ).astype (jnp .float32 )
62
+ return s_t
63
+
40
64
class BernoulliCell (JaxComponent ):
41
65
"""
42
- A Bernoulli cell that produces Bernoulli-distributed spikes on-the-fly.
66
+ A Bernoulli cell that produces variations of Bernoulli-distributed spikes
67
+ on-the-fly (including constrained-rate trains).
43
68
44
69
| --- Cell Input Compartments: ---
45
70
| inputs - input (takes in external signals)
@@ -53,12 +78,17 @@ class BernoulliCell(JaxComponent):
53
78
name: the string name of this cell
54
79
55
80
n_units: number of cellular entities (neural population size)
81
+
82
+ max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
56
83
"""
57
84
58
85
# Define Functions
59
- def __init__ (self , name , n_units , batch_size = 1 , ** kwargs ):
86
+ def __init__ (self , name , n_units , max_freq = 63.75 , batch_size = 1 , ** kwargs ):
60
87
super ().__init__ (name , ** kwargs )
61
88
89
+ ## Constrained Bernoulli meta-parameters
90
+ self .max_freq = max_freq ## maximum frequency (in Hertz/Hz)
91
+
62
92
## Layer Size Setup
63
93
self .batch_size = batch_size
64
94
self .n_units = n_units
@@ -70,11 +100,16 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
70
100
self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
71
101
72
102
@staticmethod
73
- def _advance_state (t , key , inputs , tols ):
103
+ def _advance_state (t , dt , max_freq , key , inputs , tols ):
74
104
key , * subkeys = random .split (key , 2 )
75
- outputs = _sample_bernoulli (subkeys [0 ], data = inputs )
76
- timeOfLastSpike = _update_times (t , outputs , tols )
77
- return outputs , timeOfLastSpike , key
105
+ if max_freq > 0. :
106
+ outputs = _sample_constrained_bernoulli ( ## sample Bernoulli w/ target rate
107
+ subkeys [0 ], data = inputs , dt = dt , fmax = max_freq
108
+ )
109
+ else :
110
+ outputs = _sample_bernoulli (subkeys [0 ], data = inputs )
111
+ tols = _update_times (t , outputs , tols )
112
+ return outputs , tols , key
78
113
79
114
@resolver (_advance_state )
80
115
def advance_state (self , outputs , tols , key ):
0 commit comments