1
1
from ngclearn import resolver , Component , Compartment
2
2
from ngclearn .components .jaxComponent import JaxComponent
3
+ from jax import numpy as jnp , random , jit
3
4
from ngclearn .utils import tensorstats
4
- from jax import numpy as jnp , random , jit , scipy
5
5
from functools import partial
6
6
from ngcsimlib .deprecators import deprecate_args
7
7
from ngcsimlib .logger import info , warn
8
8
9
+ @jit
10
+ def _update_times (t , s , tols ):
11
+ """
12
+ Updates time-of-last-spike (tols) variable.
13
+
14
+ Args:
15
+ t: current time (a scalar/int value)
16
+
17
+ s: binary spike vector
18
+
19
+ tols: current time-of-last-spike variable
20
+
21
+ Returns:
22
+ updated tols variable
23
+ """
24
+ _tols = (1. - s ) * tols + (s * t )
25
+ return _tols
26
+
9
27
class PoissonCell (JaxComponent ):
10
28
"""
11
- A Poisson cell that produces approximately Poisson-distributed spikes
12
- on-the-fly .
29
+ A Poisson cell that samples a homogeneous Poisson process on-the-fly to
30
+ produce a spike train .
13
31
14
32
| --- Cell Input Compartments: ---
15
33
| inputs - input (takes in external signals)
@@ -24,45 +42,33 @@ class PoissonCell(JaxComponent):
24
42
25
43
n_units: number of cellular entities (neural population size)
26
44
27
- max_freq: maximum frequency (in Hertz) of this Poisson spike train (
28
- must be > 0.)
45
+ target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
29
46
"""
30
47
31
- # Define Functions
32
48
@deprecate_args (max_freq = "target_freq" )
33
- def __init__ (self , name , n_units , target_freq = 63.75 , batch_size = 1 ,
34
- ** kwargs ):
49
+ def __init__ (self , name , n_units , target_freq = 0. , batch_size = 1 , ** kwargs ):
35
50
super ().__init__ (name , ** kwargs )
36
51
37
- ## Poisson meta-parameters
52
+ ## Constrained Bernoulli meta-parameters
38
53
self .target_freq = target_freq ## maximum frequency (in Hertz/Hz)
39
54
40
55
## Layer Size Setup
41
56
self .batch_size = batch_size
42
57
self .n_units = n_units
43
58
44
- _key , subkey = random .split (self .key .value , 2 )
45
- self .key .set (_key )
46
- ## Compartment setup
59
+ # Compartments (state of the cell, parameters, will be updated through stateless calls)
47
60
restVals = jnp .zeros ((self .batch_size , self .n_units ))
48
- self .inputs = Compartment (restVals ,
49
- display_name = "Input Stimulus" ) # input
50
- # compartment
51
- self .outputs = Compartment (restVals ,
52
- display_name = "Spikes" ) # output compartment
53
- self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" ,
54
- units = "ms" ) # time of last spike
55
- self .targets = Compartment (
56
- random .uniform (subkey , (self .batch_size , self .n_units ), minval = 0. ,
57
- maxval = 1. ))
61
+ self .inputs = Compartment (restVals , display_name = "Input Stimulus" ) # input compartment
62
+ self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
63
+ self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
58
64
59
65
def validate (self , dt = None , ** validation_kwargs ):
60
66
valid = super ().validate (** validation_kwargs )
61
67
if dt is None :
62
68
warn (f"{ self .name } requires a validation kwarg of `dt`" )
63
69
return False
64
70
## check for unstable combinations of dt and target-frequency meta-params
65
- events_per_timestep = (dt / 1000. ) * self .target_freq ## compute scaled probability
71
+ events_per_timestep = (dt / 1000. ) * self .target_freq ## compute scaled probability
66
72
if events_per_timestep > 1. :
67
73
valid = False
68
74
warn (
@@ -74,54 +80,43 @@ def validate(self, dt=None, **validation_kwargs):
74
80
return valid
75
81
76
82
@staticmethod
77
- def _advance_state (t , dt , target_freq , key , inputs , targets , tols ):
78
- ms_per_second = 1000 # ms/s
79
- events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
80
- ms_per_event = 1 / events_per_ms # ms/e
81
- time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
82
-
83
- cdf = scipy .special .gammaincc ((t + dt ) - tols ,
84
- time_step_per_event / inputs )
85
- outputs = (targets < cdf ).astype (jnp .float32 )
86
-
87
- key , subkey = random .split (key , 2 )
88
- targets = (targets * (1 - outputs ) + random .uniform (subkey ,
89
- targets .shape ) *
90
- outputs )
91
-
92
- tols = tols * (1. - outputs ) + t * outputs
93
- return outputs , tols , key , targets
83
+ def _advance_state (t , dt , target_freq , key , inputs , tols ):
84
+ key , * subkeys = random .split (key , 2 )
85
+ pspike = inputs * (dt / 1000. ) * target_freq
86
+ eps = random .uniform (subkeys [0 ], inputs .shape , minval = 0. , maxval = 1. ,
87
+ dtype = jnp .float32 )
88
+ outputs = (eps < pspike ).astype (jnp .float32 )
89
+ tols = _update_times (t , outputs , tols )
90
+ return outputs , tols , key
94
91
95
92
@resolver (_advance_state )
96
- def advance_state (self , outputs , tols , key , targets ):
93
+ def advance_state (self , outputs , tols , key ):
97
94
self .outputs .set (outputs )
98
95
self .tols .set (tols )
99
96
self .key .set (key )
100
- self .targets .set (targets )
101
97
102
98
@staticmethod
103
- def _reset (batch_size , n_units , key ):
99
+ def _reset (batch_size , n_units ):
104
100
restVals = jnp .zeros ((batch_size , n_units ))
105
- key , subkey = random .split (key , 2 )
106
- targets = random .uniform (subkey , (batch_size , n_units ))
107
- return restVals , restVals , restVals , targets , key
101
+ return restVals , restVals , restVals
108
102
109
103
@resolver (_reset )
110
- def reset (self , inputs , outputs , tols , targets , key ):
104
+ def reset (self , inputs , outputs , tols ):
111
105
self .inputs .set (inputs )
112
- self .outputs .set (outputs )
106
+ self .outputs .set (outputs ) #None
113
107
self .tols .set (tols )
114
- self .key .set (key )
115
- self .targets .set (targets )
116
108
117
109
def save (self , directory , ** kwargs ):
110
+ target_freq = (self .target_freq if isinstance (self .target_freq , float )
111
+ else jnp .ones ([[self .target_freq ]]))
118
112
file_name = directory + "/" + self .name + ".npz"
119
- jnp .savez (file_name , key = self .key .value )
113
+ jnp .savez (file_name , key = self .key .value , target_freq = target_freq )
120
114
121
115
def load (self , directory , ** kwargs ):
122
116
file_name = directory + "/" + self .name + ".npz"
123
117
data = jnp .load (file_name )
124
118
self .key .set (data ['key' ])
119
+ self .target_freq = data ['target_freq' ]
125
120
126
121
@classmethod
127
122
def help (cls ): ## component help function
0 commit comments