Skip to content

Commit

Permalink
added boundary clipping to conv/deconv hebb updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 17, 2024
1 parent 5004cbf commit e26e10d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
18 changes: 13 additions & 5 deletions ngclearn/components/synapses/convolution/hebbianConvSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
optimization is required (as Hebbian rules typically yield
adjustments for ascent)
update_bound: if set to a positive value, this enforces a maximum
magnitude used to clip updates made to synapses (default: 0, meaning no clipping)
optim_type: optimization scheme to physically alter synaptic values
once an update is computed (Default: "sgd"); supported schemes
include "sgd" and "adam"
Expand All @@ -83,13 +86,14 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
# Define Functions
def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
stride=1, padding=None, resist_scale=1., w_bound=0.,
is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
batch_size=1, **kwargs):
is_nonnegative=False, w_decay=0., sign_value=1.,
update_bound=0., optim_type="sgd", batch_size=1, **kwargs):
super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
bias_init=bias_init, resist_scale=resist_scale, stride=stride,
padding=padding, batch_size=batch_size, **kwargs)

self.eta = eta
self.update_bounds = update_bound
self.w_bounds = w_bound
self.w_decay = w_decay ## synaptic decay
self.is_nonnegative = is_nonnegative
Expand Down Expand Up @@ -158,14 +162,18 @@ def _compute_update(sign_value, w_decay, bias_init, stride, pad_args,
return dWeights, dBiases

@staticmethod
def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
stride, pad_args, delta_shape, pre, post, weights, biases,
opt_params):
def _evolve(opt, sign_value, update_bounds, w_decay, w_bounds, is_nonnegative,
bias_init, stride, pad_args, delta_shape, pre, post, weights,
biases, opt_params):
## calc dFilters / dBiases - update to filters and biases
dWeights, dBiases = HebbianConvSynapse._compute_update(
sign_value, w_decay, bias_init, stride, pad_args, delta_shape,
pre, post, weights
)
if update_bounds > 0.:
dWeights = jnp.clip(dWeights, -update_bounds, update_bounds)
if bias_init != None:
dBiases = jnp.clip(dBiases, -update_bounds, update_bounds)
if bias_init != None:
opt_params, [weights, biases] = opt(opt_params, [weights, biases],
[dWeights, dBiases])
Expand Down
13 changes: 11 additions & 2 deletions ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
optimization is required (as Hebbian rules typically yield
adjustments for ascent)
update_bound: if set to a positive value, this enforces a maximum
magnitude used to clip updates made to synapses (default: 0, meaning no clipping)
optim_type: optimization scheme to physically alter synaptic values
once an update is computed (Default: "sgd"); supported schemes
include "sgd" and "adam"
Expand All @@ -81,14 +84,16 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
# Define Functions
def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
stride=1, padding=None, resist_scale=1., w_bound=0., is_nonnegative=False,
w_decay=0., sign_value=1., optim_type="sgd", batch_size=1, **kwargs):
w_decay=0., sign_value=1., update_bound=0., optim_type="sgd",
batch_size=1, **kwargs):
super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
bias_init=bias_init, resist_scale=resist_scale,
stride=stride, padding=padding, batch_size=batch_size,
**kwargs)

self.eta = eta
self.w_bounds = w_bound
self.update_bounds = update_bound
self.w_decay = w_decay ## synaptic decay
self.is_nonnegative = is_nonnegative
self.sign_value = sign_value
Expand Down Expand Up @@ -149,13 +154,17 @@ def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding,
return dWeights, dBiases

@staticmethod
def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
def _evolve(opt, sign_value, update_bounds, w_decay, w_bounds, is_nonnegative, bias_init,
shape, stride, padding, delta_shape, pre, post, weights, biases,
opt_params):
dWeights, dBiases = HebbianDeconvSynapse._compute_update(
sign_value, w_decay, bias_init, shape, stride, padding, delta_shape,
pre, post, weights
)
if update_bounds > 0.:
dWeights = jnp.clip(dWeights, -update_bounds, update_bounds)
if bias_init != None:
dBiases = jnp.clip(dBiases, -update_bounds, update_bounds)
if bias_init != None:
opt_params, [weights, biases] = opt(opt_params, [weights, biases],
[dWeights, dBiases])
Expand Down

0 comments on commit e26e10d

Please sign in to comment.