Skip to content

Commit

Permalink
integrated phasor-cell, minor cleanup of latency
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Aug 6, 2024
1 parent bf06510 commit 77f347f
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 2 deletions.
1 change: 1 addition & 0 deletions ngclearn/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .input_encoders.bernoulliCell import BernoulliCell
from .input_encoders.poissonCell import PoissonCell
from .input_encoders.latencyCell import LatencyCell
from .input_encoders.phasorCell import PhasorCell
## point to synapse component types
from .synapses.denseSynapse import DenseSynapse
from .synapses.staticSynapse import StaticSynapse
Expand Down
1 change: 1 addition & 0 deletions ngclearn/components/input_encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .bernoulliCell import BernoulliCell
from .poissonCell import PoissonCell
from .latencyCell import LatencyCell
from .phasorCell import PhasorCell
4 changes: 2 additions & 2 deletions ngclearn/components/input_encoders/latencyCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1.,
projected spike times
"""
_tau = tau
if normalize == True:
if normalize:
_tau = num_steps - 1. - first_spk_t ## linear normalization
#torch.clamp_max((-tau * (data - 1)), -tau * (threshold - 1))
stimes = -_tau * (data - 1.) ## calc raw latency code values
Expand Down Expand Up @@ -85,7 +85,7 @@ def _calc_spike_times_nonlinear(data, tau, thr, first_spk_t, eps=1e-7,
stimes = jnp.log(_data / (_data - thr)) * tau ## calc spike times
stimes = stimes + first_spk_t

if normalize == True:
if normalize:
term1 = (stimes - first_spk_t)
term2 = (num_steps - first_spk_t - 1.)
term3 = jnp.max(stimes - first_spk_t)
Expand Down
179 changes: 179 additions & 0 deletions ngclearn/components/input_encoders/phasorCell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from ngclearn import resolver, Compartment
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils import tensorstats
from jax import numpy as jnp, random
from ngcsimlib.logger import warn

class PhasorCell(JaxComponent):
"""
A phasor cell that emits a pulse at a regular interval.
| --- Cell Input Compartments: ---
| inputs - input (takes in external signals)
| --- Cell State Compartments: ---
| key - JAX PRNG key
| --- Cell Output Compartments: ---
| outputs - output
| tols - time-of-last-spike
Args:
name: the string name of this cell
n_units: number of cellular entities (neural population size)
target_freq: maximum frequency (in Hertz) of this spike train
(must be > 0.)
"""

# Define Functions
def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
**kwargs):
super().__init__(name, **kwargs)

## Phasor meta-parameters
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)

## Layer Size Setup
self.batch_size = batch_size
self.n_units = n_units
_key, subkey = random.split(self.key.value, 2)
self.key.set(_key)
## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
self.inputs = Compartment(restVals,
display_name="Input Stimulus") # input
# compartment
self.outputs = Compartment(restVals,
display_name="Spikes") # output compartment
self.tols = Compartment(initial_value=restVals,
display_name="Time-of-Last-Spike", units="ms") # time of last spike
self.angles = Compartment(restVals, display_name="Angles", units="deg")
# self.base_scale = random.uniform(subkey, self.angles.value.shape,
# minval=0.75, maxval=1.25)
# self.base_scale = ((random.normal(subkey, self.angles.value.shape) * 0.15) + 1)
# alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1)
# beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq

self.base_scale = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq

def validate(self, dt=None, **validation_kwargs):
valid = super().validate(**validation_kwargs)
if dt is None:
warn(f"{self.name} requires a validation kwarg of `dt`")
return False
## check for unstable combinations of dt and target-frequency
# meta-params
events_per_timestep = (
dt / 1000.) * self.target_freq ##
# compute scaled probability
if events_per_timestep > 1.:
valid = False
warn(
f"{self.name} will be unable to make as many temporal events "
f"as "
f"requested! ({events_per_timestep} events/timestep) Unstable "
f"combination of dt = {dt} and target_freq = "
f"{self.target_freq} "
f"being used!"
)
return valid

@staticmethod
def _advance_state(t, dt, target_freq, key,
inputs, angles, tols, base_scale):
ms_per_second = 1000 # ms/s
events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
ms_per_event = 1 / events_per_ms # ms/e
time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
angle_per_event = 2 * jnp.pi # rad / e
angle_per_timestep = angle_per_event / time_step_per_event # rad / e
# * e/ts -> rad / ts
key, subkey = random.split(key, 2)
# scatter = random.uniform(subkey, angles.shape, minval=0.5,
# maxval=1.5) * base_scale

scatter = ((random.normal(subkey, angles.shape) * 0.2) + 1) * base_scale
scattered_update = angle_per_timestep * scatter
scaled_scattered_update = scattered_update * inputs

updated_angles = angles + scaled_scattered_update
outputs = jnp.where(updated_angles > angle_per_event, 1., 0.)
updated_angles = jnp.where(updated_angles > angle_per_event,
updated_angles - angle_per_event,
updated_angles)
tols = tols * (1. - outputs) + t * outputs

return outputs, tols, key, updated_angles

@resolver(_advance_state)
def advance_state(self, outputs, tols, key, angles):
self.outputs.set(outputs)
self.tols.set(tols)
self.key.set(key)
self.angles.set(angles)

@staticmethod
def _reset(batch_size, n_units, key, target_freq):
restVals = jnp.zeros((batch_size, n_units))
key, subkey = random.split(key, 2)
return restVals, restVals, restVals, restVals, key

@resolver(_reset)
def reset(self, inputs, outputs, tols, angles, key):
self.inputs.set(inputs)
self.outputs.set(outputs)
self.tols.set(tols)
self.key.set(key)
self.angles.set(angles)

def save(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
jnp.savez(file_name, key=self.key.value)

def load(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
self.key.set(data['key'])

@classmethod
def help(cls): ## component help function
properties = {
"cell_type": "Phasor - Produces input at a fairly regular "
"intervals with small amounts of noise)"
}
compartment_props = {
"inputs":
{"inputs": "Takes in external input signal values"},
"states":
{"key": "JAX PRNG key",
"angles": "The current angle of the phasor"},
"outputs":
{"tols": "Time-of-last-spike",
"outputs": "Binary spike values emitted at time t"},
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"batch_size": "Batch size dimension of this component",
"target_freq": "Maximum spike frequency of the train produced",
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if
Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
for c in comps:
stats = tensorstats(getattr(self, c).value)
if stats is not None:
line = [f"{k}: {v}" for k, v in stats.items()]
line = ", ".join(line)
else:
line = "None"
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
return lines


0 comments on commit 77f347f

Please sign in to comment.