diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index 390de3ec..957689c6 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -5,27 +5,6 @@ from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info -@jit -def _compute_layer(inp, weight, biases=0., Rscale=1.): - """ - Applies the transformation/projection induced by the synaptic efficacie - associated with this synaptic cable - - Args: - inp: signal input to run through this synaptic cable - - weight: this cable's synaptic value matrix - - biases: this cable's bias value vector (default: 0.) - - Rscale: scale factor to apply to synapses before transform applied - to input values (default: 1.) - - Returns: - a projection/transformation of input "inp" - """ - return jnp.matmul(inp, weight * Rscale) + biases - class DenseSynapse(JaxComponent): ## base dense synaptic cable """ A dense synaptic cable; no form of synaptic evolution/adaptation @@ -51,7 +30,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable (Default: None, which turns off/disables biases) resist_scale: a fixed (resistance) scaling factor to apply to synaptic - transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + transform (Default: 1.), i.e., yields: out = ((W * in) * resist_scale) + bias p_conn: probability of a connection existing (default: 1.); setting this to < 1 and > 0. will result in a sparser synaptic structure @@ -98,7 +77,7 @@ def __init__(self, name, shape, weight_init=None, bias_init=None, @staticmethod def _advance_state(Rscale, inputs, weights, biases): - outputs = _compute_layer(inputs, weights, biases, Rscale) + outputs = (jnp.matmul(inputs, weights) * Rscale) + biases return outputs @resolver(_advance_state) @@ -155,12 +134,12 @@ def help(cls): ## component help function "batch_size": "Batch size dimension of this component", "weight_init": "Initialization conditions for synaptic weight (W) values", "bias_init": "Initialization conditions for bias/base-rate (b) values", - "resist_scale": "Resistance level scaling factor (applied to output of transformation)", + "resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation", "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)" } info = {cls.__name__: properties, "compartments": compartment_props, - "dynamics": "outputs = [(W * Rscale) * inputs] + b", + "dynamics": "outputs = [W * inputs] * Rscale + b", "hyperparameters": hyperparams} return info