Describe the bug
Computing the Hessian (as discussed in #60) yields the zero matrix.
To Reproduce
See https://github.com/adamhaber/snn-sound-localization/blob/compute-hessians/research/Compute%20Hessians.ipynb
Specifically, the relevant code is in cells 15 onward: the function `calc_acc` computes the mean squared difference between the outputs of the original, trained network and a perturbed version of it (currently only `W2` is perturbed, since it's smaller).
This is obviously not the most efficient way to do this. I just wanted to quickly get a feel for the Hessians, but the result is always zero.
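A minimal sketch of the `calc_acc` setup described above, assuming a smooth (non-spiking) two-layer stand-in network; `forward`, `W1`, `x`, `y_ref` and the layer sizes are hypothetical and not taken from the notebook. It also checks a property that may help with debugging: at `p = 0` the residuals are exactly zero, so the Hessian of the mean squared difference reduces to `(2/N) J^T J`, where `J` is the Jacobian of the network outputs with respect to `W2` and `N` is the number of output elements. An all-zero Hessian evaluated at `p = 0` would therefore mean that `J` itself is zero, i.e. no gradient reaches `W2` from the outputs.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64

# Hypothetical stand-ins for the trained network: W1 stays fixed, W2 is the
# matrix being perturbed, x is a fixed batch of inputs.
n_in, n_hidden, n_out, batch = 5, 4, 3, 8
x = torch.randn(batch, n_in, dtype=dtype)
W1 = torch.randn(n_in, n_hidden, dtype=dtype)
W2 = torch.randn(n_hidden, n_out, dtype=dtype)


def forward(w2):
    # Smooth two-layer network with sigmoid nonlinearities
    return torch.sigmoid(torch.sigmoid(x @ W1) @ w2)


# Outputs of the original network, computed once without tracking gradients
with torch.no_grad():
    y_ref = forward(W2)


def calc_acc(p):
    # Mean squared difference between perturbed and original outputs
    return ((forward(W2 + p) - y_ref) ** 2).mean()


p0 = torch.zeros_like(W2)
H = torch.autograd.functional.hessian(calc_acc, p0)
print(H.shape)  # torch.Size([4, 3, 4, 3])

# At p = 0 the residuals vanish, so the Hessian equals (2/N) * J^T J,
# where J is the Jacobian of the outputs with respect to W2.
J = torch.autograd.functional.jacobian(forward, W2).reshape(batch * n_out, -1)
H_gn = 2.0 / (batch * n_out) * J.T @ J
H_flat = H.reshape(n_hidden * n_out, n_hidden * n_out)
print(torch.allclose(H_flat, H_gn))  # True
```

For a single tensor input, `hessian` returns a tensor of shape `input.shape + input.shape`, which is why it is reshaped to a square matrix before the comparison.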
At first I thought that this might be caused by the surrogate gradients, but I get all-zero Hessians even when I change `spike_fn = SurrGradSpike.apply` to `spike_fn = nn.Sigmoid()` (see cell 9).
I've also changed the dtype to float64 in case the zeros come from underflow, but this didn't help either.
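To check whether a surrogate-gradient spike function by itself rules out second derivatives, here is a sketch with a hypothetical `SurrGradSpike` (Heaviside forward pass, fast-sigmoid surrogate in the backward pass) applied to a single hypothetical spiking layer; the notebook's actual definition, its `scale`, and its network may differ, and `scale = 10.0` here is arbitrary. Because `backward` is written only with differentiable torch operations, autograd can differentiate through it a second time, and the Hessian at `p = 0` comes out non-zero.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64


class SurrGradSpike(torch.autograd.Function):
    """Heaviside forward pass with a fast-sigmoid surrogate gradient.

    Hypothetical re-implementation for illustration only. The backward pass
    uses differentiable torch ops, so autograd can differentiate through it
    again (needed for second derivatives).
    """

    scale = 10.0

    @staticmethod
    def forward(ctx, u):
        ctx.save_for_backward(u)
        return (u > 0).to(u.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (u,) = ctx.saved_tensors
        # Derivative of the fast sigmoid u / (1 + scale * |u|)
        return grad_output / (SurrGradSpike.scale * u.abs() + 1.0) ** 2


spike_fn = SurrGradSpike.apply

# Hypothetical single spiking layer: x is a fixed input batch, W2 the weights
x = torch.randn(8, 4, dtype=dtype)
W2 = torch.randn(4, 3, dtype=dtype)
with torch.no_grad():
    s_ref = spike_fn(x @ W2)


def calc_acc(p):
    # Mean squared difference between perturbed and original spike outputs
    return ((spike_fn(x @ (W2 + p)) - s_ref) ** 2).mean()


H = torch.autograd.functional.hessian(calc_acc, torch.zeros_like(W2))
print(H.abs().max().item() > 0)  # True
```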
Expected behavior
A non-zero matrix of second derivatives (eventually we'll be interested in `torch.autograd.functional.hessian(calc_acc, torch.zeros_like(p))`).
I don't have much experience with PyTorch (has anyone ported this to JAX? If so, I'd be happy to try it there). What am I missing?
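One thing that may be worth ruling out (a guess, not a confirmed diagnosis for this notebook): `torch.autograd.functional.hessian` defaults to `strict=False`, and in that mode it returns zeros, rather than raising an error, for an input that the output does not depend on in the autograd graph. A perturbation that enters the network through `.detach()`, `.data`, a `torch.no_grad()` block, or a fresh `torch.tensor(...)` copy would produce exactly that. The sketch below uses a deliberately broken, hypothetical `calc_acc_detached` with made-up shapes to show the difference; passing `strict=True` turns the silent zero into an error.

```python
import torch

x = torch.randn(8, 4, dtype=torch.float64)
W2 = torch.randn(4, 3, dtype=torch.float64)
y_ref = torch.sigmoid(x @ W2)


def calc_acc_detached(p):
    # Hypothetical bug: .detach() cuts p out of the autograd graph, so the
    # returned loss no longer depends on p as far as autograd can tell.
    w = (W2 + p).detach()
    return ((torch.sigmoid(x @ w) - y_ref) ** 2).mean()


p0 = torch.zeros_like(W2)

# Default strict=False: the disconnected input silently gets an all-zero Hessian
H = torch.autograd.functional.hessian(calc_acc_detached, p0)
print(bool((H == 0).all()))  # True

# strict=True: the same call raises a RuntimeError instead of returning zeros
try:
    torch.autograd.functional.hessian(calc_acc_detached, p0, strict=True)
    raised = False
except RuntimeError as err:
    raised = True
    print("strict=True raised:", err)
```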