Skip to content

Commit

Permalink
mod to rpe/mstdpet
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 16, 2024
1 parent e5c3176 commit 7b8bd7c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
20 changes: 13 additions & 7 deletions ngclearn/components/neurons/graded/rewardErrorCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,29 @@ def __init__(self, name, n_units, alpha, ema_window_len=10,
self.reward = Compartment(restVals) ## target reward signal(s)
self.rpe = Compartment(restVals) ## reward prediction error(s)
self.accum_reward = Compartment(restVals) ## accumulated reward signal(s)
self.Ns = Compartment(jnp.zeros((self.batch_size, 1)))
self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken

@staticmethod
def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
                   n_ep_steps, accum_reward, Ns):
    """
    Pure (state-free) update: computes the reward prediction error (RPE)
    for the current reward and advances the running reward statistics.

    Args:
        dt: not used by this routine (kept for the shared call signature)

        use_online_predictor: if True, the current reward is folded into
            the running predictor statistics (`mu`, `Ns`)

        alpha: not used by this routine; it parameterised the earlier
            moving-average predictor `mu * (1 - alpha) + reward * alpha`

        mu: running sum of the rewards folded into the predictor so far

        rpe: previous RPE value(s); overwritten by this call

        reward: current reward signal(s)

        n_ep_steps: count of steps taken so far

        accum_reward: accumulated reward so far

        Ns: count of rewards that have been summed into `mu`

    Returns:
        updated (mu, rpe, n_ep_steps, accum_reward, Ns)
    """
    ## compute/update RPE and predictor values
    accum_reward = accum_reward + reward
    ## The predicted reward is the running mean mu/Ns. Ns starts at (and is
    ## reset to) zero, so the divisor is floored at 1 -- otherwise the very
    ## first step evaluates 0/0 and a NaN RPE is emitted downstream.
    n_samples = jnp.maximum(Ns, 1.)
    rpe = reward - mu / n_samples
    if use_online_predictor:
        mu = mu + reward  ## running sum of observed rewards
        ## NOTE(review): the counter is advanced together with the sum it
        ## normalises -- confirm this nesting against the repository file
        Ns = Ns + 1.
    n_ep_steps = n_ep_steps + 1
    return mu, rpe, n_ep_steps, accum_reward, Ns

@resolver(_advance_state)
def advance_state(self, mu, rpe, n_ep_steps, accum_reward, Ns):
    """
    Stores the values produced by `_advance_state` into this cell's
    compartments. The parameter order mirrors that routine's return
    order: (mu, rpe, n_ep_steps, accum_reward, Ns).
    """
    self.mu.set(mu)
    self.rpe.set(rpe)
    self.n_ep_steps.set(n_ep_steps)
    self.accum_reward.set(accum_reward)
    self.Ns.set(Ns)  ## reward-sample counter used to normalise mu

@staticmethod
def _evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
Expand All @@ -88,14 +92,16 @@ def _reset(batch_size, n_units):
rpe = restVals
accum_reward = restVals
n_ep_steps = jnp.zeros((batch_size, 1))
return mu, rpe, accum_reward, n_ep_steps
Ns = jnp.zeros((batch_size, 1))
return mu, rpe, accum_reward, n_ep_steps, Ns

@resolver(_reset)
def reset(self, mu, rpe, accum_reward, n_ep_steps, Ns):
    """
    Stores the values produced by `_reset` into this cell's compartments.
    The parameter order mirrors that routine's return order:
    (mu, rpe, accum_reward, n_ep_steps, Ns).
    """
    self.mu.set(mu)
    self.rpe.set(rpe)
    self.accum_reward.set(accum_reward)
    self.n_ep_steps.set(n_ep_steps)
    self.Ns.set(Ns)  ## reward-sample counter goes back to its reset value

@classmethod
def help(cls): ## component help function
Expand Down
2 changes: 1 addition & 1 deletion ngclearn/components/synapses/modulated/MSTDPETSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus, tau_elg,
#dWeights = jnp.where(dWeights >= 0., dWeights * modulator, dWeights)
dWeights = jnp.where(modulator > 0.,
dWeights * modulator,
jnp.clip(dWeights, max=0.) * -modulator)
jnp.clip(dWeights, min=0.) * modulator)
else:
dWeights = eligibility * modulator

Expand Down

0 comments on commit 7b8bd7c

Please sign in to comment.