11from  ngclearn  import  resolver , Component , Compartment 
22from  ngclearn .components .jaxComponent  import  JaxComponent 
3+ from  jax  import  numpy  as  jnp , random , jit 
34from  ngclearn .utils  import  tensorstats 
4- from  jax  import  numpy  as  jnp , random , jit , scipy 
55from  functools  import  partial 
66from  ngcsimlib .deprecators  import  deprecate_args 
77from  ngcsimlib .logger  import  info , warn 
88
9+ @jit  
10+ def  _update_times (t , s , tols ):
11+     """ 
12+     Updates time-of-last-spike (tols) variable. 
13+ 
14+     Args: 
15+         t: current time (a scalar/int value) 
16+ 
17+         s: binary spike vector 
18+ 
19+         tols: current time-of-last-spike variable 
20+ 
21+     Returns: 
22+         updated tols variable 
23+     """ 
24+     _tols  =  (1.  -  s ) *  tols  +  (s  *  t )
25+     return  _tols 
26+ 
927class  PoissonCell (JaxComponent ):
1028    """ 
11-     A Poisson cell that produces approximately  Poisson-distributed spikes  
12-     on-the-fly . 
29+     A Poisson cell that samples a homogeneous  Poisson process on-the-fly to  
30+     produce a spike train . 
1331
1432    | --- Cell Input Compartments: --- 
1533    | inputs - input (takes in external signals) 
@@ -24,45 +42,33 @@ class PoissonCell(JaxComponent):
2442
2543        n_units: number of cellular entities (neural population size) 
2644
27-         max_freq: maximum frequency (in Hertz) of this Poisson spike train ( 
28-         must be > 0.) 
45+         target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.) 
2946    """ 
3047
31-     # Define Functions 
3248    @deprecate_args (max_freq = "target_freq" ) 
33-     def  __init__ (self , name , n_units , target_freq = 63.75 , batch_size = 1 ,
34-                  ** kwargs ):
49+     def  __init__ (self , name , n_units , target_freq = 0. , batch_size = 1 , ** kwargs ):
3550        super ().__init__ (name , ** kwargs )
3651
37-         ## Poisson  meta-parameters 
52+         ## Constrained Bernoulli  meta-parameters 
3853        self .target_freq  =  target_freq   ## maximum frequency (in Hertz/Hz) 
3954
4055        ## Layer Size Setup 
4156        self .batch_size  =  batch_size 
4257        self .n_units  =  n_units 
4358
44-         _key , subkey  =  random .split (self .key .value , 2 )
45-         self .key .set (_key )
46-         ## Compartment setup 
59+         # Compartments (state of the cell, parameters, will be updated through stateless calls) 
4760        restVals  =  jnp .zeros ((self .batch_size , self .n_units ))
48-         self .inputs  =  Compartment (restVals ,
49-                                   display_name = "Input Stimulus" )  # input 
50-         # compartment 
51-         self .outputs  =  Compartment (restVals ,
52-                                    display_name = "Spikes" )  # output compartment 
53-         self .tols  =  Compartment (restVals , display_name = "Time-of-Last-Spike" ,
54-                                 units = "ms" )  # time of last spike 
55-         self .targets  =  Compartment (
56-             random .uniform (subkey , (self .batch_size , self .n_units ), minval = 0. ,
57-                            maxval = 1. ))
61+         self .inputs  =  Compartment (restVals , display_name = "Input Stimulus" ) # input compartment 
62+         self .outputs  =  Compartment (restVals , display_name = "Spikes" ) # output compartment 
63+         self .tols  =  Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike 
5864
5965    def  validate (self , dt = None , ** validation_kwargs ):
6066        valid  =  super ().validate (** validation_kwargs )
6167        if  dt  is  None :
6268            warn (f"{ self .name }   requires a validation kwarg of `dt`" )
6369            return  False 
6470        ## check for unstable combinations of dt and target-frequency meta-params 
65-         events_per_timestep  =  (dt   /   1000. ) *  self .target_freq    ## compute scaled probability 
71+         events_per_timestep  =  (dt / 1000. ) *  self .target_freq  ## compute scaled probability 
6672        if  events_per_timestep  >  1. :
6773            valid  =  False 
6874            warn (
@@ -74,54 +80,43 @@ def validate(self, dt=None, **validation_kwargs):
7480        return  valid 
7581
7682    @staticmethod  
77-     def  _advance_state (t , dt , target_freq , key , inputs , targets , tols ):
78-         ms_per_second  =  1000   # ms/s 
79-         events_per_ms  =  target_freq  /  ms_per_second   # e/s s/ms -> e/ms 
80-         ms_per_event  =  1  /  events_per_ms   # ms/e 
81-         time_step_per_event  =  ms_per_event  /  dt   # ms/e * ts/ms -> ts / e 
82- 
83-         cdf  =  scipy .special .gammaincc ((t  +  dt ) -  tols ,
84-                                       time_step_per_event / inputs )
85-         outputs  =  (targets  <  cdf ).astype (jnp .float32 )
86- 
87-         key , subkey  =  random .split (key , 2 )
88-         targets  =  (targets  *  (1  -  outputs ) +  random .uniform (subkey ,
89-                                                            targets .shape ) * 
90-                    outputs )
91- 
92-         tols  =  tols  *  (1.  -  outputs ) +  t  *  outputs 
93-         return  outputs , tols , key , targets 
83+     def  _advance_state (t , dt , target_freq , key , inputs , tols ):
84+         key , * subkeys  =  random .split (key , 2 )
85+         pspike  =  inputs  *  (dt  /  1000. ) *  target_freq 
86+         eps  =  random .uniform (subkeys [0 ], inputs .shape , minval = 0. , maxval = 1. ,
87+                              dtype = jnp .float32 )
88+         outputs  =  (eps  <  pspike ).astype (jnp .float32 )
89+         tols  =  _update_times (t , outputs , tols )
90+         return  outputs , tols , key 
9491
9592    @resolver (_advance_state ) 
96-     def  advance_state (self , outputs , tols , key ,  targets ):
93+     def  advance_state (self , outputs , tols , key ):
9794        self .outputs .set (outputs )
9895        self .tols .set (tols )
9996        self .key .set (key )
100-         self .targets .set (targets )
10197
10298    @staticmethod  
103-     def  _reset (batch_size , n_units ,  key ):
99+     def  _reset (batch_size , n_units ):
104100        restVals  =  jnp .zeros ((batch_size , n_units ))
105-         key , subkey  =  random .split (key , 2 )
106-         targets  =  random .uniform (subkey , (batch_size , n_units ))
107-         return  restVals , restVals , restVals , targets , key 
101+         return  restVals , restVals , restVals 
108102
109103    @resolver (_reset ) 
110-     def  reset (self , inputs , outputs , tols ,  targets ,  key ):
104+     def  reset (self , inputs , outputs , tols ):
111105        self .inputs .set (inputs )
112-         self .outputs .set (outputs )
106+         self .outputs .set (outputs )  #None 
113107        self .tols .set (tols )
114-         self .key .set (key )
115-         self .targets .set (targets )
116108
117109    def  save (self , directory , ** kwargs ):
110+         target_freq  =  (self .target_freq  if  isinstance (self .target_freq , float )
111+                        else  jnp .ones ([[self .target_freq ]]))
118112        file_name  =  directory  +  "/"  +  self .name  +  ".npz" 
119-         jnp .savez (file_name , key = self .key .value )
113+         jnp .savez (file_name , key = self .key .value ,  target_freq = target_freq )
120114
121115    def  load (self , directory , ** kwargs ):
122116        file_name  =  directory  +  "/"  +  self .name  +  ".npz" 
123117        data  =  jnp .load (file_name )
124118        self .key .set (data ['key' ])
119+         self .target_freq  =  data ['target_freq' ]
125120
126121    @classmethod  
127122    def  help (cls ):  ## component help function 
0 commit comments