From 0308b27104d8e96e9cc310f64d4d05d132c2a47f Mon Sep 17 00:00:00 2001
From: n-shevko
Date: Thu, 23 Jan 2025 14:21:43 -0500
Subject: [PATCH 1/3] SparseConnection support

---
 bindsnet/learning/learning.py |  37 +++++++++---
 bindsnet/network/topology.py  | 105 +++-------------------------------
 2 files changed, 36 insertions(+), 106 deletions(-)

diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py
index e2c171cd..b7ad0927 100644
--- a/bindsnet/learning/learning.py
+++ b/bindsnet/learning/learning.py
@@ -98,7 +98,10 @@ def update(self) -> None:
             (self.connection.wmin != -np.inf).any()
             or (self.connection.wmax != np.inf).any()
         ) and not isinstance(self, NoOp):
-            self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
+            if self.connection.w.is_sparse:
+                raise Exception("SparseConnection isn't supported for wmin/wmax")
+            else:
+                self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
 
 
 class NoOp(LearningRule):
@@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None:
         if self.nu[0].any():
             source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
             target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
-            self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
+            update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+            if self.connection.w.is_sparse:
+                update = update.to_sparse()
+            self.connection.w -= update
             del source_s, target_x
 
         # Post-synaptic update.
@@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None:
                 self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
             )
             source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
-            self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
+            update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+            if self.connection.w.is_sparse:
+                update = update.to_sparse()
+            self.connection.w += update
             del source_x, target_s
 
         super().update()
@@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None:
 
         # Pre-synaptic update.
         update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         self.connection.w += self.nu[0] * update
 
         # Post-synaptic update.
         update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         self.connection.w += self.nu[1] * update
 
         super().update()
@@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None:
             a_minus = torch.tensor(a_minus, device=self.connection.w.device)
 
         # Compute weight update based on the eligibility value of the past timestep.
-        update = reward * self.eligibility
-        self.connection.w += self.nu[0] * self.reduction(update, dim=0)
+        update = self.reduction(reward * self.eligibility, dim=0)
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
+        self.connection.w += self.nu[0] * update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None:
         self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
         self.eligibility_trace += self.eligibility / self.tc_e_trace
 
+        update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
         # Compute weight update.
-        self.connection.w += (
-            self.nu[0] * self.connection.dt * reward * self.eligibility_trace
-        )
+        self.connection.w += update
 
         # Update P^+ and P^- values.
         self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None:
         ) * source_x[:, None]
 
         # Compute weight update.
-        self.connection.w += self.nu[0] * reward * self.eligibility_trace
+        update = self.nu[0] * reward * self.eligibility_trace
+        if self.connection.w.is_sparse:
+            update = update.to_sparse()
+        self.connection.w += update
 
         super().update()

diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py
index cb5fafa1..e2564cb6 100644
--- a/bindsnet/network/topology.py
+++ b/bindsnet/network/topology.py
@@ -126,9 +126,13 @@ def update(self, **kwargs) -> None:
         mask = kwargs.get("mask", None)
         if mask is not None:
+            if self.w.is_sparse:
+                raise Exception("Mask isn't supported for SparseConnection")
             self.w.masked_fill_(mask, 0)
 
         if self.Dales_rule is not None:
+            if self.w.is_sparse:
+                raise Exception("Dales_rule isn't supported for SparseConnection")
             # weight that are negative and should be positive are set to 0
             self.w[self.w < 0 * self.Dales_rule.to(torch.float)] = 0
             # weight that are positive and should be negative are set to 0
@@ -1947,105 +1951,12 @@ def reset_state_variables(self) -> None:
         super().reset_state_variables()
 
 
-class SparseConnection(AbstractConnection):
+class SparseConnection(Connection):
     # language=rst
     """
     Specifies sparse synapses between one or two populations of neurons.
     """
 
-    def __init__(
-        self,
-        source: Nodes,
-        target: Nodes,
-        nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
-        reduction: Optional[callable] = None,
-        weight_decay: float = None,
-        **kwargs,
-    ) -> None:
-        # language=rst
-        """
-        Instantiates a :code:`Connection` object with sparse weights.
-
-        :param source: A layer of nodes from which the connection originates.
-        :param target: A layer of nodes to which the connection connects.
-        :param nu: Learning rate for both pre- and post-synaptic events. It also
-            accepts a pair of tensors to individualize learning rates of each neuron.
-            In this case, their shape should be the same size as the connection weights.
-        :param reduction: Method for reducing parameter updates along the minibatch
-            dimension.
-        :param weight_decay: Constant multiple to decay weights by on each iteration.
-
-        Keyword arguments:
-
-        :param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format
-        :param float sparsity: Fraction of sparse connections to use.
-        :param LearningRule update_rule: Modifies connection parameters according to
-            some rule.
-        :param float wmin: Minimum allowed value on the connection weights.
-        :param float wmax: Maximum allowed value on the connection weights.
-        :param float norm: Total weight per target neuron normalization constant.
- """ - super().__init__(source, target, nu, reduction, weight_decay, **kwargs) - - w = kwargs.get("w", None) - self.sparsity = kwargs.get("sparsity", None) - - assert ( - w is not None - and self.sparsity is None - or w is None - and self.sparsity is not None - ), 'Only one of "weights" or "sparsity" must be specified' - - if w is None and self.sparsity is not None: - i = torch.bernoulli( - 1 - self.sparsity * torch.ones(*source.shape, *target.shape) - ) - if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any(): - v = torch.clamp( - torch.rand(*source.shape, *target.shape), self.wmin, self.wmax - )[i.bool()] - else: - v = ( - self.wmin - + torch.rand(*source.shape, *target.shape) * (self.wmax - self.wmin) - )[i.bool()] - w = torch.sparse.FloatTensor(i.nonzero().t(), v) - elif w is not None and self.sparsity is None: - assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)" - if self.wmin != -np.inf or self.wmax != np.inf: - w = torch.clamp(w, self.wmin, self.wmax) - - self.w = Parameter(w, requires_grad=False) - - def compute(self, s: torch.Tensor) -> torch.Tensor: - # language=rst - """ - Compute convolutional pre-activations given spikes using layer weights. - - :param s: Incoming spikes. - :return: Incoming spikes multiplied by synaptic weights (with or without - decaying spike activation). - """ - return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1) - # return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1) - - def update(self, **kwargs) -> None: - # language=rst - """ - Compute connection's update rule. - """ - - def normalize(self) -> None: - # language=rst - """ - Normalize weights along the first axis according to total weight per target - neuron. - """ - - def reset_state_variables(self) -> None: - # language=rst - """ - Contains resetting logic for the connection. 
- """ - super().reset_state_variables() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.w = Parameter(self.w.to_sparse(), requires_grad=False) From d1d3e42960719fedb8a7367d853bede20a001c2c Mon Sep 17 00:00:00 2001 From: n-shevko Date: Sat, 25 Jan 2025 17:17:47 -0500 Subject: [PATCH 2/3] Test for SparseConnection --- test/network/test_connections.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/network/test_connections.py b/test/network/test_connections.py index db715e7c..961f248c 100644 --- a/test/network/test_connections.py +++ b/test/network/test_connections.py @@ -109,6 +109,12 @@ def test_weights(self, conn_type, shape_a, shape_b, shape_w, *args, **kwargs): ): return + # SparseConnection isn't supported for wmin\\wmax + elif (conn_type == SparseConnection) and not ( + (torch.tensor(wmin, dtype=torch.float32) == -np.inf).all() + and (torch.tensor(wmax, dtype=torch.float32) == np.inf).all()): + continue + print( f"- w: {type(w).__name__}, " f"wmin: {type(wmax).__name__}, wmax: {type(wmax).__name__}" @@ -163,8 +169,9 @@ def test_weights(self, conn_type, shape_a, shape_b, shape_w, *args, **kwargs): # tester.test_transfer() # Connections with learning ability - conn_types = [Connection, Conv2dConnection, LocalConnection] + conn_types = [Connection, SparseConnection, Conv2dConnection, LocalConnection] args = [ + [[100], [50], (100, 50)], [[100], [50], (100, 50)], [[1, 28, 28], [1, 26, 26], (1, 1, 3, 3), 3], [[1, 28, 28], [1, 26, 26], (784, 676), 3, 1, 1], From 471d4550542c09241855d73003a4de9102ab29c3 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Sat, 15 Feb 2025 12:00:40 -0500 Subject: [PATCH 3/3] Sparsity for MulticompartmentConnection --- bindsnet/network/topology.py | 11 ++- bindsnet/network/topology_features.py | 103 ++++++++++++++++++++------ 2 files changed, 91 insertions(+), 23 deletions(-) diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index e2564cb6..0f9f047c 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -446,12 +446,19 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: # Sum signals for each of the output/terminal neurons # |out_signal| = [batch_size, target.n] - out_signal = conn_spikes.view(s.size(0), self.source.n, self.target.n).sum(1) + if conn_spikes.size() != torch.Size([s.size(0), self.source.n, self.target.n]): + if conn_spikes.is_sparse: + conn_spikes = conn_spikes.to_dense() + conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n) + out_signal = conn_spikes.sum(1) if self.traces: self.activity = out_signal - return out_signal.view(s.size(0), *self.target.shape) + if out_signal.size() != torch.Size([s.size(0)] + self.target.shape): + return out_signal.view(s.size(0), *self.target.shape) + else: + return out_signal def compute_window(self, s: torch.Tensor) -> torch.Tensor: # language=rst diff --git a/bindsnet/network/topology_features.py b/bindsnet/network/topology_features.py index f99cf39f..3ba4b9c8 100644 --- a/bindsnet/network/topology_features.py +++ b/bindsnet/network/topology_features.py @@ -31,6 +31,7 @@ def __init__( enforce_polarity: Optional[bool] = False, decay: float = 0.0, parent_feature=None, + sparse: Optional[bool] = False, **kwargs, ) -> None: # language=rst @@ -47,6 +48,7 @@ def __init__( dimension :param decay: Constant multiple to decay weights by on each iteration :param parent_feature: Parent feature to inherit :code:`value` from + :param sparse: Should :code:`value` parameter be sparse tensor or not """ 
 
         #### Initialize class variables ####
@@ -61,6 +63,7 @@ def __init__(
         self.reduction = reduction
         self.decay = decay
         self.parent_feature = parent_feature
+        self.sparse = sparse
         self.kwargs = kwargs
 
         ## Backend ##
@@ -119,6 +122,10 @@ def __init__(
         self.assert_valid_range()
         if value is not None:
             self.assert_feature_in_range()
+        if self.sparse:
+            self.value = self.value.to_sparse()
+            assert not getattr(self, 'enforce_polarity', False), \
+                "enforce_polarity isn't supported for sparse tensors"
 
     @abstractmethod
     def reset_state_variables(self) -> None:
@@ -161,7 +168,10 @@ def prime_feature(self, connection, device, **kwargs) -> None:
 
         # Check if values/norms are the correct shape
         if isinstance(self.value, torch.Tensor):
-            assert tuple(self.value.shape) == (connection.source.n, connection.target.n)
+            if self.sparse:
+                assert tuple(self.value.shape[1:]) == (connection.source.n, connection.target.n)
+            else:
+                assert tuple(self.value.shape) == (connection.source.n, connection.target.n)
 
         if self.norm is not None and isinstance(self.norm, torch.Tensor):
             assert self.norm.shape[0] == connection.target.n
@@ -214,9 +224,15 @@ def normalize(self) -> None:
         """
 
         if self.norm is not None:
-            abs_sum = self.value.sum(0).unsqueeze(0)
-            abs_sum[abs_sum == 0] = 1.0
-            self.value *= self.norm / abs_sum
+            if self.sparse:
+                abs_sum = self.value.sum(1).to_dense()
+                abs_sum[abs_sum == 0] = 1.0
+                abs_sum = abs_sum.unsqueeze(1).expand(-1, *self.value.shape[1:])
+                self.value = self.value * (self.norm / abs_sum)
+            else:
+                abs_sum = self.value.sum(0).unsqueeze(0)
+                abs_sum[abs_sum == 0] = 1.0
+                self.value *= self.norm / abs_sum
 
     def degrade(self) -> None:
         # language=rst
@@ -299,11 +315,17 @@ def assert_feature_in_range(self):
 
     def assert_valid_shape(self, source_shape, target_shape, f):
         # Multidimensional feat
-        if len(f.shape) > 1:
-            assert f.shape == (
+        if (not self.sparse and len(f.shape) > 1) or (self.sparse and len(f.shape[1:]) > 1):
+            if self.sparse:
+                f_shape = f.shape[1:]
+                expected = ('batch_size', source_shape, target_shape)
+            else:
+                f_shape = f.shape
+                expected = (source_shape, target_shape)
+            assert f_shape == (
                 source_shape,
                 target_shape,
-            ), f"Feature {self.name} has an incorrect shape of {f.shape}. Should be of shape {(source_shape, target_shape)}"
+            ), f"Feature {self.name} has an incorrect shape of {f.shape}. Should be of shape {expected}"
 
         # Else assume scalar, which is a valid shape
Should be of shape {expected}" # Else assume scalar, which is a valid shape @@ -319,6 +341,7 @@ def __init__( reduction: Optional[callable] = None, decay: float = 0.0, parent_feature=None, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -336,6 +359,7 @@ def __init__( dimension :param decay: Constant multiple to decay weights by on each iteration :param parent_feature: Parent feature to inherit :code:`value` from + :param sparse: Should :code:`value` parameter be sparse tensor or not """ ### Assertions ### @@ -349,10 +373,25 @@ def __init__( reduction=reduction, decay=decay, parent_feature=parent_feature, + sparse=sparse + ) + + def sparse_bernoulli(self): + values = torch.bernoulli(self.value.values()) + mask = values != 0 + indices = self.value.indices()[:, mask] + non_zero = values[mask] + return torch.sparse_coo_tensor( + indices, + non_zero, + self.value.size() ) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: - return conn_spikes * torch.bernoulli(self.value) + if self.sparse: + return conn_spikes * self.sparse_bernoulli() + else: + return conn_spikes * torch.bernoulli(self.value) def reset_state_variables(self) -> None: pass @@ -395,12 +434,14 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + sparse: Optional[bool] = False ) -> None: # language=rst """ Boolean mask which determines whether or not signals are allowed to traverse certain synapses. :param name: Name of the feature :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable + :param sparse: Should :code:`value` parameter be sparse tensor or not """ ### Assertions ### @@ -419,11 +460,9 @@ def __init__( super().__init__( name=name, value=value, + sparse=sparse ) - self.name = name - self.value = value - def compute(self, conn_spikes) -> torch.Tensor: return conn_spikes * self.value @@ -505,6 +544,7 @@ def __init__( reduction: Optional[callable] = None, enforce_polarity: Optional[bool] = False, decay: float = 0.0, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -523,6 +563,7 @@ def __init__( dimension :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` :param decay: Constant multiple to decay weights by on each iteration + :param sparse: Should :code:`value` parameter be sparse tensor or not """ self.norm_frequency = norm_frequency @@ -536,6 +577,7 @@ def __init__( nu=nu, reduction=reduction, decay=decay, + sparse=sparse ) def reset_state_variables(self) -> None: @@ -589,6 +631,7 @@ def __init__( value: Union[torch.Tensor, float, int] = None, range: Optional[Sequence[float]] = None, norm: Optional[Union[torch.Tensor, float, int]] = None, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -598,6 +641,7 @@ def __init__( :param range: Range of acceptable values for the :code:`value` parameter :param norm: Value which all values in :code:`value` will sum to. 
            Normalization of values occurs after each sample and after the value
            has been updated by the learning rule (if there is one)
+        :param sparse: Whether the :code:`value` parameter should be a sparse tensor
        """
 
        super().__init__(
            name=name,
            value=value,
            range=[-torch.inf, +torch.inf] if range is None else range,
            norm=norm,
+            sparse=sparse
        )
 
    def reset_state_variables(self) -> None:
        pass
@@ -629,15 +674,17 @@ def __init__(
        name: str,
        value: Union[torch.Tensor, float, int] = None,
        range: Optional[Sequence[float]] = None,
+        sparse: Optional[bool] = False
    ) -> None:
        # language=rst
        """
        Adds scalars to signals
 
        :param name: Name of the feature
        :param value: Values to scale signals by
+        :param sparse: Whether the :code:`value` parameter should be a sparse tensor
        """
 
-        super().__init__(name=name, value=value, range=range)
+        super().__init__(name=name, value=value, range=range, sparse=sparse)
 
    def reset_state_variables(self) -> None:
        pass
@@ -666,6 +713,7 @@ def __init__(
        value: Union[torch.Tensor, float, int] = None,
        degrade_function: callable = None,
        parent_feature: Optional[AbstractFeature] = None,
+        sparse: Optional[bool] = False
    ) -> None:
        # language=rst
        """
@@ -676,9 +724,10 @@ def __init__(
        :param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor
            or constant to be *subtracted* from the propagating spikes.
        :param parent_feature: Parent feature with desired :code:`value` to inherit
+        :param sparse: Whether the :code:`value` parameter should be a sparse tensor
        """
 
        # Note: parent_feature will override value. See abstract constructor
-        super().__init__(name=name, value=value, parent_feature=parent_feature)
+        super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse)
 
        self.degrade_function = degrade_function
@@ -698,6 +747,7 @@ def __init__(
        ann_values: Union[list, tuple] = None,
        const_update_rate: float = 0.1,
        const_decay: float = 0.001,
+        sparse: Optional[bool] = False
    ) -> None:
        # language=rst
        """
@@ -706,6 +756,7 @@ def __init__(
        :param value: Values to be use to build an initial mask for the synapses.
        :param const_update_rate: The mask upatate rate of the ANN decision.
        :param const_decay: The spontaneous activation of the synapses.
+        :param sparse: Whether the :code:`value` parameter should be a sparse tensor
        """
 
        # Define the ANN
@@ -743,16 +794,18 @@ def forward(self, x):
        self.const_update_rate = const_update_rate
        self.const_decay = const_decay
 
-        super().__init__(name=name, value=value)
+        super().__init__(name=name, value=value, sparse=sparse)
 
    def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
 
        # Update the spike buffer
        if self.start_counter == False or conn_spikes.sum() > 0:
            self.start_counter = True
-            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
-                conn_spikes.flatten()
-            )
+            if self.sparse:
+                flat_conn_spikes = conn_spikes.to_dense().flatten()
+            else:
+                flat_conn_spikes = conn_spikes.flatten()
+            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
            self.counter += 1
 
        # Update the masks
@@ -767,6 +820,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
        # self.mask = torch.clamp(self.mask, -1, 1)
 
        self.value = (self.mask > 0).float()
+        if self.sparse:
+            self.value = self.value.to_sparse()
 
        return conn_spikes * self.value
@@ -788,6 +843,7 @@ def __init__(
        ann_values: Union[list, tuple] = None,
        const_update_rate: float = 0.1,
        const_decay: float = 0.01,
+        sparse: Optional[bool] = False
    ) -> None:
        # language=rst
        """
@@ -796,6 +852,7 @@ def __init__(
        :param value: Values to be use to build an initial mask for the synapses.
        :param const_update_rate: The mask upatate rate of the ANN decision.
        :param const_decay: The spontaneous activation of the synapses.
+        :param sparse: Whether the :code:`value` parameter should be a sparse tensor
        """
 
        # Define the ANN
@@ -833,16 +890,18 @@ def forward(self, x):
        self.const_update_rate = const_update_rate
        self.const_decay = const_decay
 
-        super().__init__(name=name, value=value)
+        super().__init__(name=name, value=value, sparse=sparse)
 
    def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
 
        # Update the spike buffer
        if self.start_counter == False or conn_spikes.sum() > 0:
            self.start_counter = True
-            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
-                conn_spikes.flatten()
-            )
+            if self.sparse:
+                flat_conn_spikes = conn_spikes.to_dense().flatten()
+            else:
+                flat_conn_spikes = conn_spikes.flatten()
+            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
            self.counter += 1
 
        # Update the masks
@@ -857,6 +916,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
        # self.mask = torch.clamp(self.mask, -1, 1)
 
        self.value = (self.mask > 0).float()
+        if self.sparse:
+            self.value = self.value.to_sparse()
 
        return conn_spikes * self.value
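
For reference, a minimal usage sketch of the reworked SparseConnection from PATCH 1/3. This is not part of the series; it assumes the existing BindsNET Network/Input/LIFNodes API, and the layer names, rates, and sizes are illustrative:

    import torch

    from bindsnet.network import Network
    from bindsnet.network.nodes import Input, LIFNodes
    from bindsnet.network.topology import SparseConnection

    network = Network(dt=1.0)
    source = Input(n=100, traces=True)
    target = LIFNodes(n=50, traces=True)
    network.add_layer(source, name="X")
    network.add_layer(target, name="Y")

    # SparseConnection now reuses Connection's constructor and stores `w` as a
    # torch sparse COO tensor; wmin/wmax must stay at their +/-inf defaults,
    # since the clamp in LearningRule.update() raises for sparse weights.
    conn = SparseConnection(source=source, target=target, w=0.5 * torch.rand(100, 50))
    network.add_connection(conn, source="X", target="Y")

    # Drive the network with Bernoulli spike trains for 250 timesteps.
    spikes = {"X": torch.bernoulli(0.1 * torch.ones(250, 100)).byte()}
    network.run(inputs=spikes, time=250)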
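The `sparse` flag from PATCH 3/3 could be exercised along these lines. Again a sketch rather than part of the series: it assumes MulticompartmentConnection accepts its feature list via the `pipeline` argument and reuses `source`/`target` from the sketch above; per the new prime_feature/assert_valid_shape checks, sparse feature values carry a leading batch dimension, i.e. shape [batch_size, source.n, target.n]:

    import torch

    from bindsnet.network.topology import MulticompartmentConnection
    from bindsnet.network.topology_features import Probability, Weight

    # Both features convert `value` to a sparse COO tensor in the abstract
    # constructor when sparse=True.
    weight = Weight(name="weight", value=0.3 * torch.rand(1, 100, 50), sparse=True)
    prob = Probability(name="prob", value=0.9 * torch.ones(1, 100, 50), sparse=True)

    conn = MulticompartmentConnection(
        source=source, target=target, device="cpu", pipeline=[weight, prob]
    )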
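The core of the sparse Probability path is the new sparse_bernoulli helper: it draws a Bernoulli sample only for each stored (nonzero) probability and rebuilds a COO tensor from the indices that survive. The same steps can be checked standalone with the tensor below (values are illustrative):

    import torch

    # Stored probabilities; zeros are simply absent from the sparse layout.
    value = torch.tensor([[0.9, 0.0], [0.0, 0.5]]).to_sparse()

    values = torch.bernoulli(value.values())  # one draw per stored synapse
    mask = values != 0                        # keep only synapses that fired
    indices = value.indices()[:, mask]
    print(torch.sparse_coo_tensor(indices, values[mask], value.size()))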