Skip to content

Commit

Permalink
cleaned up dense-syn
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jun 30, 2024
1 parent 89a3434 commit e678750
Showing 1 changed file with 4 additions and 25 deletions.
29 changes: 4 additions & 25 deletions ngclearn/components/synapses/denseSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,6 @@
from ngclearn.utils.weight_distribution import initialize_params
from ngcsimlib.logger import info

@jit
def _compute_layer(inp, weight, biases=0., Rscale=1.):
"""
Applies the transformation/projection induced by the synaptic efficacie
associated with this synaptic cable
Args:
inp: signal input to run through this synaptic cable
weight: this cable's synaptic value matrix
biases: this cable's bias value vector (default: 0.)
Rscale: scale factor to apply to synapses before transform applied
to input values (default: 1.)
Returns:
a projection/transformation of input "inp"
"""
return jnp.matmul(inp, weight * Rscale) + biases

class DenseSynapse(JaxComponent): ## base dense synaptic cable
"""
A dense synaptic cable; no form of synaptic evolution/adaptation
Expand All @@ -51,7 +30,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
(Default: None, which turns off/disables biases)
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
transform (Default: 1.), i.e., yields: out = ((W * in) * resist_scale) + bias
p_conn: probability of a connection existing (default: 1.); setting
this to < 1 and > 0. will result in a sparser synaptic structure
Expand Down Expand Up @@ -98,7 +77,7 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,

@staticmethod
def _advance_state(Rscale, inputs, weights, biases):
outputs = _compute_layer(inputs, weights, biases, Rscale)
outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
return outputs

@resolver(_advance_state)
Expand Down Expand Up @@ -155,12 +134,12 @@ def help(cls): ## component help function
"batch_size": "Batch size dimension of this component",
"weight_init": "Initialization conditions for synaptic weight (W) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
"resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"dynamics": "outputs = [(W * Rscale) * inputs] + b",
"dynamics": "outputs = [W * inputs] * Rscale + b",
"hyperparameters": hyperparams}
return info

Expand Down

0 comments on commit e678750

Please sign in to comment.