# locally_connected_multi_chan.py

from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter

from bindsnet.learning import NoOp
from bindsnet.network.nodes import Nodes
from bindsnet.network.topology import AbstractConnection


class LocalConnection2D(AbstractConnection):
    """
    2D local connection between one or two populations of neurons, supporting
    multi-channel 3D inputs. The logic differs from the BindsNET
    implementation, but some lines are unchanged.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        input_shape: Tuple[int, int],
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        **kwargs,
    ) -> None:
"""
Instantiates a 'LocalConnection` object. Source population can be multi-channel
Neurons in the post-synaptic population are ordered by receptive field; that is,
if there are `n_conv` neurons in each post-synaptic patch, then the first
`n_conv` neurons in the post-synaptic population correspond to the first
receptive field, the second ``n_conv`` to the second receptive field, and so on.
:param source: A layer of nodes from which the connection originates.
:param target: A layer of nodes to which the connection connects.
:param kernel_size: Horizontal and vertical size of convolutional kernels.
:param stride: Horizontal and vertical stride for convolution.
:param in_channels: The number of input channels
:param out_channels: The number of output channels
:param input_shape: The 2D shape of each input channel
:param nu: Learning rate for both pre- and post-synaptic events.
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to
some rule. For now, only PostPre has been implemented for the multi-channel-input implementation
:param torch.Tensor w: Strengths of synapses.
:param torch.Tensor b: Target population bias.
:param float wmin: Minimum allowed value on the connection weights.
:param float wmax: Maximum allowed value on the connection weights.
:param float norm: Total weight per target neuron normalization constant.
"""
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = kwargs.get("padding", 0)

        shape = input_shape
        if self.kernel_size == tuple(shape):
            conv_size = (1, 1)
        else:
            conv_size = (
                (shape[0] - self.kernel_size[0]) // self.stride[0] + 1,
                (shape[1] - self.kernel_size[1]) // self.stride[1] + 1,
            )
        self.conv_size = conv_size
        self.conv_prod = int(np.prod(conv_size))
        self.kernel_prod = int(np.prod(self.kernel_size))

        assert target.n == out_channels * self.conv_prod, (
            "Target layer size must be out_channels * n_receptive_fields "
            f"(= {out_channels * self.conv_prod})."
        )
        w = kwargs.get("w", None)
        if w is None:
            w = torch.rand(
                self.in_channels,
                self.out_channels * self.conv_prod,
                self.kernel_prod,
            )
        else:
            assert w.shape == (
                self.in_channels,
                self.out_channels * self.conv_prod,
                self.kernel_prod,
            )
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)
        self.b = Parameter(kwargs.get("b", None), requires_grad=False)

        # Keep a handle on the real learning rule so that learning can be
        # toggled on and off via (de)activate_learning().
        self._active_update_rule = self.update_rule
        self._deactivated_update_rule = NoOp(connection=self)
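
    # Worked shape example (illustrative, not from the original file): with a
    # (1, 28, 28) input, kernel_size=12, stride=4, and out_channels=8, we get
    # conv_size = ((28 - 12) // 4 + 1,) * 2 = (5, 5), conv_prod = 25, and
    # kernel_prod = 144, so w has shape (1, 8 * 25, 144) and the target layer
    # must have n = 8 * 25 = 200 neurons.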

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        """
        Compute pre-activations given spikes using layer weights.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights (with or
            without decaying spike activation).
        """
        # Multiply pre-synaptic spikes by the connection weights.
        # s: (batch, ch_in, h_in, w_in)
        #   => s_unfold: (batch, ch_in, ch_out * h_out * w_out, k ** 2)
        # w: (ch_in, ch_out * h_out * w_out, k ** 2)
        # a_post: (batch, ch_in, ch_out * h_out * w_out, k ** 2)
        #   => (batch, ch_out * h_out * w_out)  (= target.n)
        self.s_unfold = (
            s.unfold(-2, self.kernel_size[0], self.stride[0])
            .unfold(-2, self.kernel_size[1], self.stride[1])
            .reshape(
                s.shape[0],
                self.in_channels,
                self.conv_prod,
                self.kernel_prod,
            )
            .repeat(1, 1, self.out_channels, 1)
        )
        a_post = self.s_unfold.to(self.w.device) * self.w.unsqueeze(0)
        return a_post.sum(-1).sum(1).view(
            a_post.shape[0], self.out_channels, *self.conv_size
        )
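
    # Shape walk-through for ``compute`` (illustrative, not from the original
    # file), using the worked example above with a batch of 2:
    #   s:                     (2, 1, 28, 28)
    #   after the two unfolds: (2, 1, 5, 5, 12, 12)  one 12x12 patch per field
    #   after reshape:         (2, 1, 25, 144)
    #   after repeat:          (2, 1, 200, 144)      patches tiled per filter
    #   a_post.sum(-1).sum(1): (2, 200) -> view -> (2, 8, 5, 5)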

    def update(self, **kwargs) -> None:
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        """
        Normalize weights so each target neuron has sum of connection weights
        equal to ``self.norm``.
        """
        if self.norm is not None:
            # Get a flattened view and modify it in place.
            # w: (ch_in, ch_out * h_out * w_out, k ** 2)
            w = self.w.view(
                self.w.shape[0] * self.w.shape[1], self.w.shape[2]
            )
            for fltr in range(w.shape[0]):
                w[fltr, :] *= self.norm / w[fltr, :].sum(0)
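
    # Note: an equivalent vectorized form of the loop above (a suggestion,
    # not from the original file) would be:
    #     w *= self.norm / w.sum(1, keepdim=True)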

    def reset_state_variables(self) -> None:
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
        self.target.reset_state_variables()

    def activate_learning(self) -> None:
        """Re-enable the connection's learning rule."""
        self.update_rule = self._active_update_rule

    def deactivate_learning(self) -> None:
        """Swap in a ``NoOp`` rule so the weights stay fixed."""
        self.update_rule = self._deactivated_update_rule
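

def _example_local_connection() -> torch.Tensor:
    """
    A minimal usage sketch (not part of the original module). It assumes
    BindsNET's ``Input`` and ``LIFNodes`` node types; any pair of ``Nodes``
    with matching sizes would do.
    """
    from bindsnet.network.nodes import Input, LIFNodes

    # 1-channel 28x28 input; 12x12 kernels with stride 4 give 5x5 = 25
    # receptive fields, so 8 filters need a target of 8 * 25 = 200 neurons.
    source = Input(n=28 * 28, shape=(1, 28, 28))
    target = LIFNodes(n=8 * 25)
    conn = LocalConnection2D(
        source=source,
        target=target,
        kernel_size=12,
        stride=4,
        in_channels=1,
        out_channels=8,
        input_shape=(28, 28),
    )
    spikes = (torch.rand(1, 1, 28, 28) < 0.1).float()
    return conn.compute(spikes)  # shape: (1, 8, 5, 5)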


class DepthWiseLocalConnection2D(LocalConnection2D):
    """
    Depth-wise variant of ``LocalConnection2D``. Over
    w: (ch_in, ch_out * h_out * w_out, k ** 2), ``self.mask`` is 1 inside each
    filter's window of ``kernel_depth`` consecutive input channels and 0
    elsewhere; it is handed to the learning rule via ``update``.
    """

    def __init__(self, kernel_depth: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.kernel_depth = kernel_depth
        n_depth_positions = self.in_channels - self.kernel_depth + 1
        assert self.out_channels % n_depth_positions == 0, (
            f"out_channels (={self.out_channels}) must be divisible by "
            f"in_channels - kernel_depth + 1 (={n_depth_positions})"
        )
        self.n_same_depth_filters = self.out_channels // n_depth_positions
        self.mask = torch.zeros_like(self.w)
        for depth in range(n_depth_positions):
            self.mask[
                depth : depth + self.kernel_depth,
                self.n_same_depth_filters * depth * self.conv_prod
                : self.n_same_depth_filters * (depth + 1) * self.conv_prod,
                :,
            ] = 1

    def update(self, **kwargs) -> None:
        """
        Compute connection's update rule, passing the depth-wise mask.
        """
        if kwargs.get("mask") is None:
            kwargs["mask"] = self.mask
        super().update(**kwargs)
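

if __name__ == "__main__":
    # Smoke test for the depth-wise variant (an illustrative sketch, not part
    # of the original module; node types are assumed from BindsNET). With
    # in_channels=4 and kernel_depth=2 there are 4 - 2 + 1 = 3 depth
    # positions, so out_channels must be a multiple of 3.
    from bindsnet.network.nodes import Input, LIFNodes

    source = Input(n=4 * 28 * 28, shape=(4, 28, 28))
    target = LIFNodes(n=6 * 25)  # 6 filters x 25 receptive fields
    conn = DepthWiseLocalConnection2D(
        kernel_depth=2,
        source=source,
        target=target,
        kernel_size=12,
        stride=4,
        in_channels=4,
        out_channels=6,
        input_shape=(28, 28),
    )
    spikes = (torch.rand(1, 4, 28, 28) < 0.1).float()
    print(conn.compute(spikes).shape)  # expected: torch.Size([1, 6, 5, 5])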