1
1
from ngclearn import resolver , Component , Compartment
2
2
from ngclearn .components .jaxComponent import JaxComponent
3
3
from ngclearn .utils import tensorstats
4
- from jax import numpy as jnp , random , jit
4
+ from jax import numpy as jnp , random , jit , scipy
5
5
from functools import partial
6
+ from ngcsimlib .deprecators import deprecate_args
6
7
7
- @jit
8
- def _update_times (t , s , tols ):
9
- """
10
- Updates time-of-last-spike (tols) variable.
11
-
12
- Args:
13
- t: current time (a scalar/int value)
14
-
15
- s: binary spike vector
16
-
17
- tols: current time-of-last-spike variable
18
-
19
- Returns:
20
- updated tols variable
21
- """
22
- _tols = (1. - s ) * tols + (s * t )
23
- return _tols
24
-
25
- @partial (jit , static_argnums = [3 ])
26
- def _sample_poisson (dkey , data , dt , fmax = 63.75 ):
27
- """
28
- Samples a Poisson spike train on-the-fly.
29
-
30
- Args:
31
- dkey: JAX key to drive stochasticity/noise
32
-
33
- data: sensory data (vector/matrix)
34
-
35
- dt: integration time constant
36
-
37
- fmax: maximum frequency (Hz)
38
-
39
- Returns:
40
- binary spikes
41
- """
42
- pspike = data * (dt / 1000. ) * fmax
43
- eps = random .uniform (dkey , data .shape , minval = 0. , maxval = 1. , dtype = jnp .float32 )
44
- s_t = (eps < pspike ).astype (jnp .float32 )
45
- return s_t
46
8
47
9
class PoissonCell (JaxComponent ):
48
10
"""
49
- A Poisson cell that produces approximately Poisson-distributed spikes on-the-fly.
11
+ A Poisson cell that produces approximately Poisson-distributed spikes
12
+ on-the-fly.
50
13
51
14
| --- Cell Input Compartments: ---
52
15
| inputs - input (takes in external signals)
@@ -61,49 +24,78 @@ class PoissonCell(JaxComponent):
61
24
62
25
n_units: number of cellular entities (neural population size)
63
26
64
- max_freq: maximum frequency (in Hertz) of this Poisson spike train (must be > 0.)
27
+ max_freq: maximum frequency (in Hertz) of this Poisson spike train (
28
+ must be > 0.)
65
29
"""
66
30
67
31
# Define Functions
68
- def __init__ (self , name , n_units , max_freq = 63.75 , batch_size = 1 , ** kwargs ):
32
+ @deprecate_args (target_freq = "max_freq" )
33
+ def __init__ (self , name , n_units , target_freq = 63.75 , batch_size = 1 ,
34
+ ** kwargs ):
69
35
super ().__init__ (name , ** kwargs )
70
36
71
37
## Poisson meta-parameters
72
- self .max_freq = max_freq ## maximum frequency (in Hertz/Hz)
38
+ self .target_freq = target_freq ## maximum frequency (in Hertz/Hz)
73
39
74
40
## Layer Size Setup
75
41
self .batch_size = batch_size
76
42
self .n_units = n_units
77
43
44
+ _key , subkey = random .split (self .key .value , 2 )
45
+ self .key .set (_key )
78
46
## Compartment setup
79
47
restVals = jnp .zeros ((self .batch_size , self .n_units ))
80
- self .inputs = Compartment (restVals , display_name = "Input Stimulus" ) # input compartment
81
- self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
82
- self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
48
+ self .inputs = Compartment (restVals ,
49
+ display_name = "Input Stimulus" ) # input
50
+ # compartment
51
+ self .outputs = Compartment (restVals ,
52
+ display_name = "Spikes" ) # output compartment
53
+ self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" ,
54
+ units = "ms" ) # time of last spike
55
+ self .targets = Compartment (
56
+ random .uniform (subkey , (self .batch_size , self .n_units ), minval = 0. ,
57
+ maxval = 1. ))
83
58
84
59
@staticmethod
85
- def _advance_state (t , dt , max_freq , key , inputs , tols ):
86
- key , * subkeys = random .split (key , 2 )
87
- outputs = _sample_poisson (subkeys [0 ], data = inputs , dt = dt , fmax = max_freq )
88
- tols = _update_times (t , outputs , tols )
89
- return outputs , tols , key
60
+ def _advance_state (t , dt , target_freq , key , inputs , targets , tols ):
61
+ ms_per_second = 1000 # ms/s
62
+ events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
63
+ ms_per_event = 1 / events_per_ms # ms/e
64
+ time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
65
+
66
+ cdf = scipy .special .gammaincc ((t + dt ) - tols ,
67
+ time_step_per_event / inputs )
68
+ outputs = (targets < cdf ).astype (jnp .float32 )
69
+
70
+ key , subkey = random .split (key , 2 )
71
+ targets = (targets * (1 - outputs ) + random .uniform (subkey ,
72
+ targets .shape ) *
73
+ outputs )
74
+
75
+ tols = tols * (1. - outputs ) + t * outputs
76
+ return outputs , tols , key , targets
90
77
91
78
@resolver (_advance_state )
92
- def advance_state (self , outputs , tols , key ):
79
+ def advance_state (self , outputs , tols , key , targets ):
93
80
self .outputs .set (outputs )
94
81
self .tols .set (tols )
95
82
self .key .set (key )
83
+ self .targets .set (targets )
96
84
97
85
@staticmethod
98
- def _reset (batch_size , n_units ):
86
+ def _reset (batch_size , n_units , key ):
99
87
restVals = jnp .zeros ((batch_size , n_units ))
100
- return restVals , restVals , restVals
88
+ key , subkey = random .split (key , 2 )
89
+ targets = random .uniform (subkey , (batch_size , n_units ))
90
+ return restVals , restVals , restVals , targets , key
101
91
102
92
@resolver (_reset )
103
- def reset (self , inputs , outputs , tols ):
93
+ def reset (self , inputs , outputs , tols , targets , key ):
104
94
self .inputs .set (inputs )
105
95
self .outputs .set (outputs )
106
96
self .tols .set (tols )
97
+ self .key .set (key )
98
+ self .targets .set (targets )
107
99
108
100
def save (self , directory , ** kwargs ):
109
101
file_name = directory + "/" + self .name + ".npz"
@@ -115,36 +107,39 @@ def load(self, directory, **kwargs):
115
107
self .key .set (data ['key' ])
116
108
117
109
@classmethod
118
- def help (cls ): ## component help function
110
+ def help (cls ): ## component help function
119
111
properties = {
120
112
"cell_type" : "PoissonCell - samples input to produce spikes, "
121
- "where dimension is a probability proportional to "
122
- "the dimension's magnitude/value/intensity and "
123
- "constrained by a maximum spike frequency (spikes follow "
113
+ "where dimension is a probability proportional to "
114
+ "the dimension's magnitude/value/intensity and "
115
+ "constrained by a maximum spike frequency (spikes "
116
+ "follow "
124
117
"a Poisson distribution)"
125
118
}
126
119
compartment_props = {
127
120
"inputs" :
128
121
{"inputs" : "Takes in external input signal values" },
129
122
"states" :
130
- {"key" : "JAX PRNG key" },
123
+ {"key" : "JAX PRNG key" ,
124
+ "targets" : "Target cdf for the Poisson distribution" },
131
125
"outputs" :
132
126
{"tols" : "Time-of-last-spike" ,
133
127
"outputs" : "Binary spike values emitted at time t" },
134
128
}
135
129
hyperparams = {
136
130
"n_units" : "Number of neuronal cells to model in this layer" ,
137
131
"batch_size" : "Batch size dimension of this component" ,
138
- "max_freq " : "Maximum spike frequency of the train produced" ,
132
+ "target_freq " : "Maximum spike frequency of the train produced" ,
139
133
}
140
134
info = {cls .__name__ : properties ,
141
135
"compartments" : compartment_props ,
142
- "dynamics" : "~ Poisson(x; max_freq )" ,
136
+ "dynamics" : "~ Poisson(x; target_freq )" ,
143
137
"hyperparameters" : hyperparams }
144
138
return info
145
139
146
140
def __repr__ (self ):
147
- comps = [varname for varname in dir (self ) if Compartment .is_compartment (getattr (self , varname ))]
141
+ comps = [varname for varname in dir (self ) if
142
+ Compartment .is_compartment (getattr (self , varname ))]
148
143
maxlen = max (len (c ) for c in comps ) + 5
149
144
lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
150
145
for c in comps :
@@ -157,8 +152,10 @@ def __repr__(self):
157
152
lines += f" { f'({ c } )' .ljust (maxlen )} { line } \n "
158
153
return lines
159
154
155
+
160
156
if __name__ == '__main__' :
161
157
from ngcsimlib .context import Context
158
+
162
159
with Context ("Bar" ) as bar :
163
160
X = PoissonCell ("X" , 9 )
164
161
print (X )
0 commit comments