5
5
from ngclearn .utils .weight_distribution import initialize_params
6
6
from ngcsimlib .logger import info
7
7
8
- @jit
9
- def _compute_layer (inp , weight , biases = 0. , Rscale = 1. ):
10
- """
11
- Applies the transformation/projection induced by the synaptic efficacie
12
- associated with this synaptic cable
13
-
14
- Args:
15
- inp: signal input to run through this synaptic cable
16
-
17
- weight: this cable's synaptic value matrix
18
-
19
- biases: this cable's bias value vector (default: 0.)
20
-
21
- Rscale: scale factor to apply to synapses before transform applied
22
- to input values (default: 1.)
23
-
24
- Returns:
25
- a projection/transformation of input "inp"
26
- """
27
- return jnp .matmul (inp , weight * Rscale ) + biases
28
-
29
8
class DenseSynapse (JaxComponent ): ## base dense synaptic cable
30
9
"""
31
10
A dense synaptic cable; no form of synaptic evolution/adaptation
@@ -51,7 +30,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
51
30
(Default: None, which turns off/disables biases)
52
31
53
32
resist_scale: a fixed (resistance) scaling factor to apply to synaptic
54
- transform (Default: 1.), i.e., yields: out = ((W * Rscale ) * in)
33
+ transform (Default: 1.), i.e., yields: out = ((W * in ) * resist_scale) + bias
55
34
56
35
p_conn: probability of a connection existing (default: 1.); setting
57
36
this to < 1 and > 0. will result in a sparser synaptic structure
@@ -98,7 +77,7 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
98
77
99
78
@staticmethod
100
79
def _advance_state (Rscale , inputs , weights , biases ):
101
- outputs = _compute_layer ( inputs , weights , biases , Rscale )
80
+ outputs = ( jnp . matmul ( inputs , weights ) * Rscale ) + biases
102
81
return outputs
103
82
104
83
@resolver (_advance_state )
@@ -155,12 +134,12 @@ def help(cls): ## component help function
155
134
"batch_size" : "Batch size dimension of this component" ,
156
135
"weight_init" : "Initialization conditions for synaptic weight (W) values" ,
157
136
"bias_init" : "Initialization conditions for bias/base-rate (b) values" ,
158
- "resist_scale" : "Resistance level scaling factor (applied to output of transformation) " ,
137
+ "resist_scale" : "Resistance level scaling factor (Rscale); applied to output of transformation" ,
159
138
"p_conn" : "Probability of a connection existing (otherwise, it is masked to zero)"
160
139
}
161
140
info = {cls .__name__ : properties ,
162
141
"compartments" : compartment_props ,
163
- "dynamics" : "outputs = [( W * Rscale) * inputs] + b" ,
142
+ "dynamics" : "outputs = [W * inputs] * Rscale + b" ,
164
143
"hyperparameters" : hyperparams }
165
144
return info
166
145
0 commit comments