11from ngclearn import resolver , Component , Compartment
22from ngclearn .components .jaxComponent import JaxComponent
3+ from jax import numpy as jnp , random , jit
34from ngclearn .utils import tensorstats
4- from jax import numpy as jnp , random , jit , scipy
55from functools import partial
66from ngcsimlib .deprecators import deprecate_args
77from ngcsimlib .logger import info , warn
88
9+ @jit
10+ def _update_times (t , s , tols ):
11+ """
12+ Updates time-of-last-spike (tols) variable.
13+
14+ Args:
15+ t: current time (a scalar/int value)
16+
17+ s: binary spike vector
18+
19+ tols: current time-of-last-spike variable
20+
21+ Returns:
22+ updated tols variable
23+ """
24+ _tols = (1. - s ) * tols + (s * t )
25+ return _tols
26+
927class PoissonCell (JaxComponent ):
1028 """
11- A Poisson cell that produces approximately Poisson-distributed spikes
12- on-the-fly .
29+ A Poisson cell that samples a homogeneous Poisson process on-the-fly to
30+ produce a spike train .
1331
1432 | --- Cell Input Compartments: ---
1533 | inputs - input (takes in external signals)
@@ -24,45 +42,33 @@ class PoissonCell(JaxComponent):
2442
2543 n_units: number of cellular entities (neural population size)
2644
27- max_freq: maximum frequency (in Hertz) of this Poisson spike train (
28- must be > 0.)
45+ target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
2946 """
3047
31- # Define Functions
3248 @deprecate_args (max_freq = "target_freq" )
33- def __init__ (self , name , n_units , target_freq = 63.75 , batch_size = 1 ,
34- ** kwargs ):
49+ def __init__ (self , name , n_units , target_freq = 0. , batch_size = 1 , ** kwargs ):
3550 super ().__init__ (name , ** kwargs )
3651
37- ## Poisson meta-parameters
52+ ## Constrained Bernoulli meta-parameters
3853 self .target_freq = target_freq ## maximum frequency (in Hertz/Hz)
3954
4055 ## Layer Size Setup
4156 self .batch_size = batch_size
4257 self .n_units = n_units
4358
44- _key , subkey = random .split (self .key .value , 2 )
45- self .key .set (_key )
46- ## Compartment setup
59+ # Compartments (state of the cell, parameters, will be updated through stateless calls)
4760 restVals = jnp .zeros ((self .batch_size , self .n_units ))
48- self .inputs = Compartment (restVals ,
49- display_name = "Input Stimulus" ) # input
50- # compartment
51- self .outputs = Compartment (restVals ,
52- display_name = "Spikes" ) # output compartment
53- self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" ,
54- units = "ms" ) # time of last spike
55- self .targets = Compartment (
56- random .uniform (subkey , (self .batch_size , self .n_units ), minval = 0. ,
57- maxval = 1. ))
61+ self .inputs = Compartment (restVals , display_name = "Input Stimulus" ) # input compartment
62+ self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
63+ self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
5864
5965 def validate (self , dt = None , ** validation_kwargs ):
6066 valid = super ().validate (** validation_kwargs )
6167 if dt is None :
6268 warn (f"{ self .name } requires a validation kwarg of `dt`" )
6369 return False
6470 ## check for unstable combinations of dt and target-frequency meta-params
65- events_per_timestep = (dt / 1000. ) * self .target_freq ## compute scaled probability
71+ events_per_timestep = (dt / 1000. ) * self .target_freq ## compute scaled probability
6672 if events_per_timestep > 1. :
6773 valid = False
6874 warn (
@@ -74,54 +80,43 @@ def validate(self, dt=None, **validation_kwargs):
7480 return valid
7581
7682 @staticmethod
77- def _advance_state (t , dt , target_freq , key , inputs , targets , tols ):
78- ms_per_second = 1000 # ms/s
79- events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
80- ms_per_event = 1 / events_per_ms # ms/e
81- time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
82-
83- cdf = scipy .special .gammaincc ((t + dt ) - tols ,
84- time_step_per_event / inputs )
85- outputs = (targets < cdf ).astype (jnp .float32 )
86-
87- key , subkey = random .split (key , 2 )
88- targets = (targets * (1 - outputs ) + random .uniform (subkey ,
89- targets .shape ) *
90- outputs )
91-
92- tols = tols * (1. - outputs ) + t * outputs
93- return outputs , tols , key , targets
83+ def _advance_state (t , dt , target_freq , key , inputs , tols ):
84+ key , * subkeys = random .split (key , 2 )
85+ pspike = inputs * (dt / 1000. ) * target_freq
86+ eps = random .uniform (subkeys [0 ], inputs .shape , minval = 0. , maxval = 1. ,
87+ dtype = jnp .float32 )
88+ outputs = (eps < pspike ).astype (jnp .float32 )
89+ tols = _update_times (t , outputs , tols )
90+ return outputs , tols , key
9491
9592 @resolver (_advance_state )
96- def advance_state (self , outputs , tols , key , targets ):
93+ def advance_state (self , outputs , tols , key ):
9794 self .outputs .set (outputs )
9895 self .tols .set (tols )
9996 self .key .set (key )
100- self .targets .set (targets )
10197
10298 @staticmethod
103- def _reset (batch_size , n_units , key ):
99+ def _reset (batch_size , n_units ):
104100 restVals = jnp .zeros ((batch_size , n_units ))
105- key , subkey = random .split (key , 2 )
106- targets = random .uniform (subkey , (batch_size , n_units ))
107- return restVals , restVals , restVals , targets , key
101+ return restVals , restVals , restVals
108102
109103 @resolver (_reset )
110- def reset (self , inputs , outputs , tols , targets , key ):
104+ def reset (self , inputs , outputs , tols ):
111105 self .inputs .set (inputs )
112- self .outputs .set (outputs )
106+ self .outputs .set (outputs ) #None
113107 self .tols .set (tols )
114- self .key .set (key )
115- self .targets .set (targets )
116108
117109 def save (self , directory , ** kwargs ):
110+ target_freq = (self .target_freq if isinstance (self .target_freq , float )
111+ else jnp .ones ([[self .target_freq ]]))
118112 file_name = directory + "/" + self .name + ".npz"
119- jnp .savez (file_name , key = self .key .value )
113+ jnp .savez (file_name , key = self .key .value , target_freq = target_freq )
120114
121115 def load (self , directory , ** kwargs ):
122116 file_name = directory + "/" + self .name + ".npz"
123117 data = jnp .load (file_name )
124118 self .key .set (data ['key' ])
119+ self .target_freq = data ['target_freq' ]
125120
126121 @classmethod
127122 def help (cls ): ## component help function
0 commit comments