SparseConnection support #703

Open · wants to merge 3 commits into master
37 changes: 28 additions & 9 deletions bindsnet/learning/learning.py
@@ -98,7 +98,10 @@ def update(self) -> None:
             (self.connection.wmin != -np.inf).any()
             or (self.connection.wmax != np.inf).any()
         ) and not isinstance(self, NoOp):
-            self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
+            if self.connection.w.is_sparse:
+                raise Exception("SparseConnection isn't supported for wmin/wmax")
+            else:
+                self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
 
 
 class NoOp(LearningRule):
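
Note: the guard above exists because in-place clamp_ is only implemented for dense (strided) tensors. A minimal sketch of the limitation, assuming a recent PyTorch (not part of this diff):

    import torch

    w = torch.rand(3, 3)
    w.clamp_(0.0, 1.0)  # dense tensor: fine

    ws = w.to_sparse()
    try:
        ws.clamp_(0.0, 1.0)  # no sparse COO kernel for clamp_
    except (NotImplementedError, RuntimeError) as err:
        print("clamp_ on sparse weights fails:", err)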
@@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None:
         if self.nu[0].any():
             source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
             target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
-            self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
+            update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+            if self.connection.w.is_sparse:
+                update = update.to_sparse()
+            self.connection.w -= update
             del source_s, target_x
 
         # Post-synaptic update.
@@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None:
                 self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
             )
             source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
-            self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
+            update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+            if self.connection.w.is_sparse:
+                update = update.to_sparse()
+            self.connection.w += update
             del source_x, target_s
 
         super().update()
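
Note: the dense-to-sparse conversion before the in-place update is what makes this work; an in-place op on a sparse tensor generally requires a sparse operand. A small sketch under that assumption (not part of the diff):

    import torch

    w = torch.rand(4, 4).to_sparse()  # sparse COO weights, as in SparseConnection
    update = torch.rand(4, 4)         # dense result of torch.bmm + reduction

    # w -= update would raise (sparse lhs, dense rhs);
    # converting first keeps both operands in COO layout:
    w = w - update.to_sparse()
    print(w.is_sparse)  # True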
@@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None:
 
         # Pre-synaptic update.
         update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         self.connection.w += self.nu[0] * update
 
         # Post-synaptic update.
         update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         self.connection.w += self.nu[1] * update
 
         super().update()
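
Note: only the conversion is new here; scaling by the learning rate still works because multiplying a sparse tensor by a scalar just scales its stored values, and adding two sparse tensors merges their indices. A quick sketch (not part of the diff):

    import torch

    w = torch.rand(3, 3).to_sparse()
    update = torch.rand(3, 3).to_sparse()

    w = w + 0.05 * update  # scalar * sparse and sparse + sparse are supported
    print(w.is_sparse)     # True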
@@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None:
         a_minus = torch.tensor(a_minus, device=self.connection.w.device)
 
         # Compute weight update based on the eligibility value of the past timestep.
-        update = reward * self.eligibility
-        self.connection.w += self.nu[0] * self.reduction(update, dim=0)
+        update = self.reduction(reward * self.eligibility, dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
+        self.connection.w += self.nu[0] * update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None:
         self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
         self.eligibility_trace += self.eligibility / self.tc_e_trace
 
+        update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         # Compute weight update.
-        self.connection.w += (
-            self.nu[0] * self.connection.dt * reward * self.eligibility_trace
-        )
+        self.connection.w += update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
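
Note: the full update (learning rate, dt, reward, eligibility trace) is assembled densely and converted once, so only a single to_sparse() call is paid per step. The trace itself follows a standard decay-and-accumulate pattern; a standalone sketch with made-up constants (not part of the diff):

    import torch

    dt, tc_e_trace = 1.0, 25.0
    eligibility_trace = torch.zeros(3, 3)
    eligibility = torch.rand(3, 3)

    # exponential decay, then accumulation, mirroring the two trace lines above
    eligibility_trace *= torch.exp(torch.tensor(-dt / tc_e_trace))
    eligibility_trace += eligibility / tc_e_trace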
@@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None:
         ) * source_x[:, None]
 
         # Compute weight update.
-        self.connection.w += self.nu[0] * reward * self.eligibility_trace
+        update = self.nu[0] * reward * self.eligibility_trace
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
+        self.connection.w += update
 
         super().update()
116 changes: 17 additions & 99 deletions bindsnet/network/topology.py
@@ -126,9 +126,13 @@ def update(self, **kwargs) -> None:
 
         mask = kwargs.get("mask", None)
         if mask is not None:
+            if self.w.is_sparse:
+                raise Exception("Mask isn't supported for SparseConnection")
             self.w.masked_fill_(mask, 0)
 
         if self.Dales_rule is not None:
+            if self.w.is_sparse:
+                raise Exception("Dales_rule isn't supported for SparseConnection")
             # weight that are negative and should be positive are set to 0
             self.w[self.w < 0 * self.Dales_rule.to(torch.float)] = 0
             # weight that are positive and should be negative are set to 0
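
Note: as with clamp_, masked_fill_ and boolean mask indexing have no sparse COO implementations, so both paths are rejected up front for sparse weights. A minimal sketch (not part of the diff):

    import torch

    w = torch.rand(3, 3)
    mask = torch.rand(3, 3) > 0.5
    w.masked_fill_(mask, 0)  # dense tensor: fine

    ws = w.to_sparse()
    try:
        ws.masked_fill_(mask, 0)  # no sparse kernel for masked_fill_
    except (NotImplementedError, RuntimeError) as err:
        print("masking sparse weights fails:", err)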
@@ -442,12 +446,19 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
 
         # Sum signals for each of the output/terminal neurons
         # |out_signal| = [batch_size, target.n]
-        out_signal = conn_spikes.view(s.size(0), self.source.n, self.target.n).sum(1)
+        if conn_spikes.size() != torch.Size([s.size(0), self.source.n, self.target.n]):
+            if conn_spikes.is_sparse:
+                conn_spikes = conn_spikes.to_dense()
+            conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n)
+        out_signal = conn_spikes.sum(1)
 
         if self.traces:
             self.activity = out_signal
 
-        return out_signal.view(s.size(0), *self.target.shape)
+        if out_signal.size() != torch.Size([s.size(0)] + self.target.shape):
+            return out_signal.view(s.size(0), *self.target.shape)
+        else:
+            return out_signal
 
     def compute_window(self, s: torch.Tensor) -> torch.Tensor:
         # language=rst
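
Note: the densify-before-reshape step above is needed because view() requires a strided (dense) layout and is not defined for sparse COO tensors. A small sketch (not part of the diff):

    import torch

    conn_spikes = torch.rand(2, 12).to_sparse()
    # conn_spikes.view(2, 3, 4) would raise on a sparse tensor
    dense = conn_spikes.to_dense().view(2, 3, 4)  # [batch, source.n, target.n]
    out_signal = dense.sum(1)                     # [batch, target.n]
    print(out_signal.shape)                       # torch.Size([2, 4])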
@@ -1947,105 +1958,12 @@ def reset_state_variables(self) -> None:
         super().reset_state_variables()
 
 
-class SparseConnection(AbstractConnection):
+class SparseConnection(Connection):
     # language=rst
     """
     Specifies sparse synapses between one or two populations of neurons.
     """
 
-    def __init__(
-        self,
-        source: Nodes,
-        target: Nodes,
-        nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
-        reduction: Optional[callable] = None,
-        weight_decay: float = None,
-        **kwargs,
-    ) -> None:
-        # language=rst
-        """
-        Instantiates a :code:`Connection` object with sparse weights.
-
-        :param source: A layer of nodes from which the connection originates.
-        :param target: A layer of nodes to which the connection connects.
-        :param nu: Learning rate for both pre- and post-synaptic events. It also
-            accepts a pair of tensors to individualize learning rates of each neuron.
-            In this case, their shape should be the same size as the connection weights.
-        :param reduction: Method for reducing parameter updates along the minibatch
-            dimension.
-        :param weight_decay: Constant multiple to decay weights by on each iteration.
-
-        Keyword arguments:
-
-        :param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format
-        :param float sparsity: Fraction of sparse connections to use.
-        :param LearningRule update_rule: Modifies connection parameters according to
-            some rule.
-        :param float wmin: Minimum allowed value on the connection weights.
-        :param float wmax: Maximum allowed value on the connection weights.
-        :param float norm: Total weight per target neuron normalization constant.
-        """
-        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)
-
-        w = kwargs.get("w", None)
-        self.sparsity = kwargs.get("sparsity", None)
-
-        assert (
-            w is not None
-            and self.sparsity is None
-            or w is None
-            and self.sparsity is not None
-        ), 'Only one of "weights" or "sparsity" must be specified'
-
-        if w is None and self.sparsity is not None:
-            i = torch.bernoulli(
-                1 - self.sparsity * torch.ones(*source.shape, *target.shape)
-            )
-            if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
-                v = torch.clamp(
-                    torch.rand(*source.shape, *target.shape), self.wmin, self.wmax
-                )[i.bool()]
-            else:
-                v = (
-                    self.wmin
-                    + torch.rand(*source.shape, *target.shape) * (self.wmax - self.wmin)
-                )[i.bool()]
-            w = torch.sparse.FloatTensor(i.nonzero().t(), v)
-        elif w is not None and self.sparsity is None:
-            assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)"
-            if self.wmin != -np.inf or self.wmax != np.inf:
-                w = torch.clamp(w, self.wmin, self.wmax)
-
-        self.w = Parameter(w, requires_grad=False)
-
-    def compute(self, s: torch.Tensor) -> torch.Tensor:
-        # language=rst
-        """
-        Compute convolutional pre-activations given spikes using layer weights.
-
-        :param s: Incoming spikes.
-        :return: Incoming spikes multiplied by synaptic weights (with or without
-            decaying spike activation).
-        """
-        return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1)
-        # return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)
-
-    def update(self, **kwargs) -> None:
-        # language=rst
-        """
-        Compute connection's update rule.
-        """
-
-    def normalize(self) -> None:
-        # language=rst
-        """
-        Normalize weights along the first axis according to total weight per target
-        neuron.
-        """
-
-    def reset_state_variables(self) -> None:
-        # language=rst
-        """
-        Contains resetting logic for the connection.
-        """
-        super().reset_state_variables()
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.w = Parameter(self.w.to_sparse(), requires_grad=False)
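
Note: with this change SparseConnection accepts the same arguments as Connection and simply stores the constructed weights in COO format. A hypothetical usage sketch (the layer sizes and the explicit w are made up; not part of the diff):

    import torch
    from bindsnet.network.nodes import Input, LIFNodes
    from bindsnet.network.topology import SparseConnection

    source = Input(n=100)
    target = LIFNodes(n=50)
    conn = SparseConnection(source=source, target=target, w=torch.rand(100, 50))
    print(conn.w.is_sparse)  # True: __init__ converts the dense weights via to_sparse()

Subclassing Connection removes the duplicated sparse-specific __init__/compute logic; the sparse layout becomes an implementation detail of the stored Parameter.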