Skip to content

Commit e678750

Browse files
committed
cleaned up dense-syn
1 parent 89a3434 commit e678750

File tree

1 file changed

+4
-25
lines changed

1 file changed

+4
-25
lines changed

ngclearn/components/synapses/denseSynapse.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,6 @@
55
from ngclearn.utils.weight_distribution import initialize_params
66
from ngcsimlib.logger import info
77

8-
@jit
9-
def _compute_layer(inp, weight, biases=0., Rscale=1.):
10-
"""
11-
Applies the transformation/projection induced by the synaptic efficacie
12-
associated with this synaptic cable
13-
14-
Args:
15-
inp: signal input to run through this synaptic cable
16-
17-
weight: this cable's synaptic value matrix
18-
19-
biases: this cable's bias value vector (default: 0.)
20-
21-
Rscale: scale factor to apply to synapses before transform applied
22-
to input values (default: 1.)
23-
24-
Returns:
25-
a projection/transformation of input "inp"
26-
"""
27-
return jnp.matmul(inp, weight * Rscale) + biases
28-
298
class DenseSynapse(JaxComponent): ## base dense synaptic cable
309
"""
3110
A dense synaptic cable; no form of synaptic evolution/adaptation
@@ -51,7 +30,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
5130
(Default: None, which turns off/disables biases)
5231
5332
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
54-
transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
33+
transform (Default: 1.), i.e., yields: out = ((W * in) * resist_scale) + bias
5534
5635
p_conn: probability of a connection existing (default: 1.); setting
5736
this to < 1 and > 0. will result in a sparser synaptic structure
@@ -98,7 +77,7 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
9877

9978
@staticmethod
10079
def _advance_state(Rscale, inputs, weights, biases):
101-
outputs = _compute_layer(inputs, weights, biases, Rscale)
80+
outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
10281
return outputs
10382

10483
@resolver(_advance_state)
@@ -155,12 +134,12 @@ def help(cls): ## component help function
155134
"batch_size": "Batch size dimension of this component",
156135
"weight_init": "Initialization conditions for synaptic weight (W) values",
157136
"bias_init": "Initialization conditions for bias/base-rate (b) values",
158-
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
137+
"resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
159138
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
160139
}
161140
info = {cls.__name__: properties,
162141
"compartments": compartment_props,
163-
"dynamics": "outputs = [(W * Rscale) * inputs] + b",
142+
"dynamics": "outputs = [W * inputs] * Rscale + b",
164143
"hyperparameters": hyperparams}
165144
return info
166145

0 commit comments

Comments
 (0)