|
| 1 | +from ngclearn import resolver, Compartment |
| 2 | +from ngclearn.components.jaxComponent import JaxComponent |
| 3 | +from ngclearn.utils import tensorstats |
| 4 | +from jax import numpy as jnp, random |
| 5 | +from ngcsimlib.logger import warn |
| 6 | + |
| 7 | +class PhasorCell(JaxComponent): |
| 8 | + """ |
| 9 | + A phasor cell that emits a pulse at a regular interval. |
| 10 | +
|
| 11 | + | --- Cell Input Compartments: --- |
| 12 | + | inputs - input (takes in external signals) |
| 13 | + | --- Cell State Compartments: --- |
| 14 | + | key - JAX PRNG key |
| 15 | + | --- Cell Output Compartments: --- |
| 16 | + | outputs - output |
| 17 | + | tols - time-of-last-spike |
| 18 | +
|
| 19 | + Args: |
| 20 | + name: the string name of this cell |
| 21 | +
|
| 22 | + n_units: number of cellular entities (neural population size) |
| 23 | +
|
| 24 | + target_freq: maximum frequency (in Hertz) of this spike train |
| 25 | + (must be > 0.) |
| 26 | + """ |
| 27 | + |
| 28 | + # Define Functions |
| 29 | + def __init__(self, name, n_units, target_freq=63.75, batch_size=1, |
| 30 | + **kwargs): |
| 31 | + super().__init__(name, **kwargs) |
| 32 | + |
| 33 | + ## Phasor meta-parameters |
| 34 | + self.target_freq = target_freq ## maximum frequency (in Hertz/Hz) |
| 35 | + |
| 36 | + ## Layer Size Setup |
| 37 | + self.batch_size = batch_size |
| 38 | + self.n_units = n_units |
| 39 | + _key, subkey = random.split(self.key.value, 2) |
| 40 | + self.key.set(_key) |
| 41 | + ## Compartment setup |
| 42 | + restVals = jnp.zeros((self.batch_size, self.n_units)) |
| 43 | + self.inputs = Compartment(restVals, |
| 44 | + display_name="Input Stimulus") # input |
| 45 | + # compartment |
| 46 | + self.outputs = Compartment(restVals, |
| 47 | + display_name="Spikes") # output compartment |
| 48 | + self.tols = Compartment(initial_value=restVals, |
| 49 | + display_name="Time-of-Last-Spike", units="ms") # time of last spike |
| 50 | + self.angles = Compartment(restVals, display_name="Angles", units="deg") |
| 51 | + # self.base_scale = random.uniform(subkey, self.angles.value.shape, |
| 52 | + # minval=0.75, maxval=1.25) |
| 53 | + # self.base_scale = ((random.normal(subkey, self.angles.value.shape) * 0.15) + 1) |
| 54 | + # alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1) |
| 55 | + # beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq |
| 56 | + |
| 57 | + self.base_scale = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq |
| 58 | + |
| 59 | + def validate(self, dt=None, **validation_kwargs): |
| 60 | + valid = super().validate(**validation_kwargs) |
| 61 | + if dt is None: |
| 62 | + warn(f"{self.name} requires a validation kwarg of `dt`") |
| 63 | + return False |
| 64 | + ## check for unstable combinations of dt and target-frequency |
| 65 | + # meta-params |
| 66 | + events_per_timestep = ( |
| 67 | + dt / 1000.) * self.target_freq ## |
| 68 | + # compute scaled probability |
| 69 | + if events_per_timestep > 1.: |
| 70 | + valid = False |
| 71 | + warn( |
| 72 | + f"{self.name} will be unable to make as many temporal events " |
| 73 | + f"as " |
| 74 | + f"requested! ({events_per_timestep} events/timestep) Unstable " |
| 75 | + f"combination of dt = {dt} and target_freq = " |
| 76 | + f"{self.target_freq} " |
| 77 | + f"being used!" |
| 78 | + ) |
| 79 | + return valid |
| 80 | + |
| 81 | + @staticmethod |
| 82 | + def _advance_state(t, dt, target_freq, key, |
| 83 | + inputs, angles, tols, base_scale): |
| 84 | + ms_per_second = 1000 # ms/s |
| 85 | + events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms |
| 86 | + ms_per_event = 1 / events_per_ms # ms/e |
| 87 | + time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e |
| 88 | + angle_per_event = 2 * jnp.pi # rad / e |
| 89 | + angle_per_timestep = angle_per_event / time_step_per_event # rad / e |
| 90 | + # * e/ts -> rad / ts |
| 91 | + key, subkey = random.split(key, 2) |
| 92 | + # scatter = random.uniform(subkey, angles.shape, minval=0.5, |
| 93 | + # maxval=1.5) * base_scale |
| 94 | + |
| 95 | + scatter = ((random.normal(subkey, angles.shape) * 0.2) + 1) * base_scale |
| 96 | + scattered_update = angle_per_timestep * scatter |
| 97 | + scaled_scattered_update = scattered_update * inputs |
| 98 | + |
| 99 | + updated_angles = angles + scaled_scattered_update |
| 100 | + outputs = jnp.where(updated_angles > angle_per_event, 1., 0.) |
| 101 | + updated_angles = jnp.where(updated_angles > angle_per_event, |
| 102 | + updated_angles - angle_per_event, |
| 103 | + updated_angles) |
| 104 | + tols = tols * (1. - outputs) + t * outputs |
| 105 | + |
| 106 | + return outputs, tols, key, updated_angles |
| 107 | + |
| 108 | + @resolver(_advance_state) |
| 109 | + def advance_state(self, outputs, tols, key, angles): |
| 110 | + self.outputs.set(outputs) |
| 111 | + self.tols.set(tols) |
| 112 | + self.key.set(key) |
| 113 | + self.angles.set(angles) |
| 114 | + |
| 115 | + @staticmethod |
| 116 | + def _reset(batch_size, n_units, key, target_freq): |
| 117 | + restVals = jnp.zeros((batch_size, n_units)) |
| 118 | + key, subkey = random.split(key, 2) |
| 119 | + return restVals, restVals, restVals, restVals, key |
| 120 | + |
| 121 | + @resolver(_reset) |
| 122 | + def reset(self, inputs, outputs, tols, angles, key): |
| 123 | + self.inputs.set(inputs) |
| 124 | + self.outputs.set(outputs) |
| 125 | + self.tols.set(tols) |
| 126 | + self.key.set(key) |
| 127 | + self.angles.set(angles) |
| 128 | + |
| 129 | + def save(self, directory, **kwargs): |
| 130 | + file_name = directory + "/" + self.name + ".npz" |
| 131 | + jnp.savez(file_name, key=self.key.value) |
| 132 | + |
| 133 | + def load(self, directory, **kwargs): |
| 134 | + file_name = directory + "/" + self.name + ".npz" |
| 135 | + data = jnp.load(file_name) |
| 136 | + self.key.set(data['key']) |
| 137 | + |
| 138 | + @classmethod |
| 139 | + def help(cls): ## component help function |
| 140 | + properties = { |
| 141 | + "cell_type": "Phasor - Produces input at a fairly regular " |
| 142 | + "intervals with small amounts of noise)" |
| 143 | + } |
| 144 | + compartment_props = { |
| 145 | + "inputs": |
| 146 | + {"inputs": "Takes in external input signal values"}, |
| 147 | + "states": |
| 148 | + {"key": "JAX PRNG key", |
| 149 | + "angles": "The current angle of the phasor"}, |
| 150 | + "outputs": |
| 151 | + {"tols": "Time-of-last-spike", |
| 152 | + "outputs": "Binary spike values emitted at time t"}, |
| 153 | + } |
| 154 | + hyperparams = { |
| 155 | + "n_units": "Number of neuronal cells to model in this layer", |
| 156 | + "batch_size": "Batch size dimension of this component", |
| 157 | + "target_freq": "Maximum spike frequency of the train produced", |
| 158 | + } |
| 159 | + info = {cls.__name__: properties, |
| 160 | + "compartments": compartment_props, |
| 161 | + "hyperparameters": hyperparams} |
| 162 | + return info |
| 163 | + |
| 164 | + def __repr__(self): |
| 165 | + comps = [varname for varname in dir(self) if |
| 166 | + Compartment.is_compartment(getattr(self, varname))] |
| 167 | + maxlen = max(len(c) for c in comps) + 5 |
| 168 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 169 | + for c in comps: |
| 170 | + stats = tensorstats(getattr(self, c).value) |
| 171 | + if stats is not None: |
| 172 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 173 | + line = ", ".join(line) |
| 174 | + else: |
| 175 | + line = "None" |
| 176 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 177 | + return lines |
| 178 | + |
| 179 | + |
0 commit comments