11from ngclearn import resolver , Component , Compartment
22from ngclearn .components .jaxComponent import JaxComponent
33from ngclearn .utils import tensorstats
4- from jax import numpy as jnp , random , jit
4+ from jax import numpy as jnp , random , jit , scipy
55from functools import partial
6+ from ngcsimlib .deprecators import deprecate_args
67
7- @jit
8- def _update_times (t , s , tols ):
9- """
10- Updates time-of-last-spike (tols) variable.
11-
12- Args:
13- t: current time (a scalar/int value)
14-
15- s: binary spike vector
16-
17- tols: current time-of-last-spike variable
18-
19- Returns:
20- updated tols variable
21- """
22- _tols = (1. - s ) * tols + (s * t )
23- return _tols
24-
25- @partial (jit , static_argnums = [3 ])
26- def _sample_poisson (dkey , data , dt , fmax = 63.75 ):
27- """
28- Samples a Poisson spike train on-the-fly.
29-
30- Args:
31- dkey: JAX key to drive stochasticity/noise
32-
33- data: sensory data (vector/matrix)
34-
35- dt: integration time constant
36-
37- fmax: maximum frequency (Hz)
38-
39- Returns:
40- binary spikes
41- """
42- pspike = data * (dt / 1000. ) * fmax
43- eps = random .uniform (dkey , data .shape , minval = 0. , maxval = 1. , dtype = jnp .float32 )
44- s_t = (eps < pspike ).astype (jnp .float32 )
45- return s_t
468
479class PoissonCell (JaxComponent ):
4810 """
49- A Poisson cell that produces approximately Poisson-distributed spikes on-the-fly.
11+ A Poisson cell that produces approximately Poisson-distributed spikes
12+ on-the-fly.
5013
5114 | --- Cell Input Compartments: ---
5215 | inputs - input (takes in external signals)
@@ -61,49 +24,78 @@ class PoissonCell(JaxComponent):
6124
6225 n_units: number of cellular entities (neural population size)
6326
64- max_freq: maximum frequency (in Hertz) of this Poisson spike train (must be > 0.)
27+ max_freq: maximum frequency (in Hertz) of this Poisson spike train (
28+ must be > 0.)
6529 """
6630
6731 # Define Functions
68- def __init__ (self , name , n_units , max_freq = 63.75 , batch_size = 1 , ** kwargs ):
32+ @deprecate_args (target_freq = "max_freq" )
33+ def __init__ (self , name , n_units , target_freq = 63.75 , batch_size = 1 ,
34+ ** kwargs ):
6935 super ().__init__ (name , ** kwargs )
7036
7137 ## Poisson meta-parameters
72- self .max_freq = max_freq ## maximum frequency (in Hertz/Hz)
38+ self .target_freq = target_freq ## maximum frequency (in Hertz/Hz)
7339
7440 ## Layer Size Setup
7541 self .batch_size = batch_size
7642 self .n_units = n_units
7743
44+ _key , subkey = random .split (self .key .value , 2 )
45+ self .key .set (_key )
7846 ## Compartment setup
7947 restVals = jnp .zeros ((self .batch_size , self .n_units ))
80- self .inputs = Compartment (restVals , display_name = "Input Stimulus" ) # input compartment
81- self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
82- self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
48+ self .inputs = Compartment (restVals ,
49+ display_name = "Input Stimulus" ) # input
50+ # compartment
51+ self .outputs = Compartment (restVals ,
52+ display_name = "Spikes" ) # output compartment
53+ self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" ,
54+ units = "ms" ) # time of last spike
55+ self .targets = Compartment (
56+ random .uniform (subkey , (self .batch_size , self .n_units ), minval = 0. ,
57+ maxval = 1. ))
8358
8459 @staticmethod
85- def _advance_state (t , dt , max_freq , key , inputs , tols ):
86- key , * subkeys = random .split (key , 2 )
87- outputs = _sample_poisson (subkeys [0 ], data = inputs , dt = dt , fmax = max_freq )
88- tols = _update_times (t , outputs , tols )
89- return outputs , tols , key
60+ def _advance_state (t , dt , target_freq , key , inputs , targets , tols ):
61+ ms_per_second = 1000 # ms/s
62+ events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
63+ ms_per_event = 1 / events_per_ms # ms/e
64+ time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
65+
66+ cdf = scipy .special .gammaincc ((t + dt ) - tols ,
67+ time_step_per_event / inputs )
68+ outputs = (targets < cdf ).astype (jnp .float32 )
69+
70+ key , subkey = random .split (key , 2 )
71+ targets = (targets * (1 - outputs ) + random .uniform (subkey ,
72+ targets .shape ) *
73+ outputs )
74+
75+ tols = tols * (1. - outputs ) + t * outputs
76+ return outputs , tols , key , targets
9077
9178 @resolver (_advance_state )
92- def advance_state (self , outputs , tols , key ):
79+ def advance_state (self , outputs , tols , key , targets ):
9380 self .outputs .set (outputs )
9481 self .tols .set (tols )
9582 self .key .set (key )
83+ self .targets .set (targets )
9684
9785 @staticmethod
98- def _reset (batch_size , n_units ):
86+ def _reset (batch_size , n_units , key ):
9987 restVals = jnp .zeros ((batch_size , n_units ))
100- return restVals , restVals , restVals
88+ key , subkey = random .split (key , 2 )
89+ targets = random .uniform (subkey , (batch_size , n_units ))
90+ return restVals , restVals , restVals , targets , key
10191
10292 @resolver (_reset )
103- def reset (self , inputs , outputs , tols ):
93+ def reset (self , inputs , outputs , tols , targets , key ):
10494 self .inputs .set (inputs )
10595 self .outputs .set (outputs )
10696 self .tols .set (tols )
97+ self .key .set (key )
98+ self .targets .set (targets )
10799
108100 def save (self , directory , ** kwargs ):
109101 file_name = directory + "/" + self .name + ".npz"
@@ -115,36 +107,39 @@ def load(self, directory, **kwargs):
115107 self .key .set (data ['key' ])
116108
117109 @classmethod
118- def help (cls ): ## component help function
110+ def help (cls ): ## component help function
119111 properties = {
120112 "cell_type" : "PoissonCell - samples input to produce spikes, "
121- "where dimension is a probability proportional to "
122- "the dimension's magnitude/value/intensity and "
123- "constrained by a maximum spike frequency (spikes follow "
113+ "where dimension is a probability proportional to "
114+ "the dimension's magnitude/value/intensity and "
115+ "constrained by a maximum spike frequency (spikes "
116+ "follow "
124117 "a Poisson distribution)"
125118 }
126119 compartment_props = {
127120 "inputs" :
128121 {"inputs" : "Takes in external input signal values" },
129122 "states" :
130- {"key" : "JAX PRNG key" },
123+ {"key" : "JAX PRNG key" ,
124+ "targets" : "Target cdf for the Poisson distribution" },
131125 "outputs" :
132126 {"tols" : "Time-of-last-spike" ,
133127 "outputs" : "Binary spike values emitted at time t" },
134128 }
135129 hyperparams = {
136130 "n_units" : "Number of neuronal cells to model in this layer" ,
137131 "batch_size" : "Batch size dimension of this component" ,
138- "max_freq " : "Maximum spike frequency of the train produced" ,
132+ "target_freq " : "Maximum spike frequency of the train produced" ,
139133 }
140134 info = {cls .__name__ : properties ,
141135 "compartments" : compartment_props ,
142- "dynamics" : "~ Poisson(x; max_freq )" ,
136+ "dynamics" : "~ Poisson(x; target_freq )" ,
143137 "hyperparameters" : hyperparams }
144138 return info
145139
146140 def __repr__ (self ):
147- comps = [varname for varname in dir (self ) if Compartment .is_compartment (getattr (self , varname ))]
141+ comps = [varname for varname in dir (self ) if
142+ Compartment .is_compartment (getattr (self , varname ))]
148143 maxlen = max (len (c ) for c in comps ) + 5
149144 lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
150145 for c in comps :
@@ -157,8 +152,10 @@ def __repr__(self):
157152 lines += f" { f'({ c } )' .ljust (maxlen )} { line } \n "
158153 return lines
159154
155+
160156if __name__ == '__main__' :
161157 from ngcsimlib .context import Context
158+
162159 with Context ("Bar" ) as bar :
163160 X = PoissonCell ("X" , 9 )
164161 print (X )
0 commit comments