
Commit 0bd0c9d

Add tests against Neel's Anthropic paper comment implementation (#122)
1 parent 130ee59 commit 0bd0c9d

File tree

4 files changed: +349 -0 lines changed


.vscode/cspell.json

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
     "capturable",
     "categoricalwprobabilities",
     "circuitsvis",
+    "coeff",
     "colab",
     "cuda",
     "cudnn",

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@
 addopts=[
     "--doctest-modules",
     "--jaxtyping-packages=sparse_autoencoder,beartype.beartype",
+    "-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
     "-s",
 ]
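The added `-W` entry passes a standard Python warnings filter through pytest, silencing beartype's PEP 585 deprecation warning for the whole test run. As a sketch only (not what this commit does, and assuming the addopts block lives under pytest's `[tool.pytest.ini_options]` table), the same filter could equivalently be declared with pytest's `filterwarnings` ini option:

# Sketch: equivalent filter via filterwarnings (assumes pytest config lives in this table)
[tool.pytest.ini_options]
filterwarnings = [
    "ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
]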

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
"""Compare the SAE implementation to Neel's 1L Implementation.

https://github.com/neelnanda-io/1L-Sparse-Autoencoder/blob/main/utils.py
"""
import torch
from torch import nn

from sparse_autoencoder.autoencoder.model import SparseAutoencoder


class NeelAutoencoder(nn.Module):
    """Neel's 1L autoencoder implementation."""

    def __init__(
        self,
        d_hidden: int,
        act_size: int,
        l1_coeff: float,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the autoencoder."""
        super().__init__()
        self.b_dec = nn.Parameter(torch.zeros(act_size, dtype=dtype))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(act_size, d_hidden, dtype=dtype))
        )
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, act_size, dtype=dtype))
        )

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass."""
        x_cent = x - self.b_dec
        acts = nn.functional.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    def make_decoder_weights_and_grad_unit_norm(self) -> None:
        """Make decoder weights and gradient unit norm."""
        with torch.no_grad():
            weight_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
            weight_dec_grad_proj = (self.W_dec.grad * weight_dec_normed).sum(
                -1, keepdim=True
            ) * weight_dec_normed
            self.W_dec.grad -= weight_dec_grad_proj
            # Bugfix(?)
            self.W_dec.data = weight_dec_normed


def test_biases_initialised_same_way() -> None:
    """Test that the biases are initialised the same."""
    n_input_features: int = 2
    n_learned_features: int = 3
    l1_coefficient: float = 0.01

    torch.random.manual_seed(0)
    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )

    torch.random.manual_seed(0)
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    assert torch.allclose(autoencoder.tied_bias, neel_autoencoder.b_dec)
    # Note we can't compare weights as Neel's implementation uses rotated tensors and applies
    # kaiming incorrectly (uses leaky relu version and incorrect fan strategy for the rotation
    # used). Note also that the encoder bias is initialised to zero in Neel's implementation,
    # whereas we use the standard PyTorch initialisation.


def test_forward_pass_same_weights() -> None:
    """Test a forward pass with the same weights."""
    n_input_features: int = 12
    n_learned_features: int = 48
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    # Set the same weights
    autoencoder.encoder.weight.data = neel_autoencoder.W_enc.data.T
    autoencoder.decoder.weight.data = neel_autoencoder.W_dec.data.T
    autoencoder.tied_bias.data = neel_autoencoder.b_dec.data
    autoencoder.encoder.bias.data = neel_autoencoder.b_enc.data

    # Create some test data
    test_batch = torch.randn(4, n_input_features)
    learned, hidden = autoencoder.forward(test_batch)
    _loss, x_reconstruct, acts, _l2_loss, _l1_loss = neel_autoencoder.forward(test_batch)

    assert torch.allclose(learned, acts)
    assert torch.allclose(hidden, x_reconstruct)


def test_unit_norm_weights() -> None:
    """Test that the decoder weights are unit normalized in the same way."""
    n_input_features: int = 2
    n_learned_features: int = 4
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )
    pre_unit_norm_weights = autoencoder.decoder.weight.clone()
    pre_unit_norm_neel_weights = neel_autoencoder.W_dec.clone()

    # Set the same decoder weights
    decoder_weights = torch.rand_like(autoencoder.decoder.weight)
    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.data = decoder_weights.T

    # Do a forward & backward pass so we have gradients
    test_batch = torch.randn(4, n_input_features)
    _learned, decoded = autoencoder.forward(test_batch)
    decoded.sum().backward()
    decoded = neel_autoencoder.forward(test_batch)[1]
    decoded.sum().backward()

    # Apply the unit norm
    autoencoder.decoder.constrain_weights_unit_norm()
    neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

    # Check the decoder weights are the same with both models
    assert torch.allclose(autoencoder.decoder.weight, neel_autoencoder.W_dec.T)

    # Check the trivial case that the weights haven't just stayed the same as before the unit norm
    assert not torch.allclose(autoencoder.decoder.weight, pre_unit_norm_weights)
    assert not torch.allclose(neel_autoencoder.W_dec, pre_unit_norm_neel_weights)

def test_unit_norm_weights_grad() -> None:
    """Test that the decoder weight gradients match after applying the unit norm constraint."""
    torch.random.manual_seed(42)
    n_input_features: int = 2
    n_learned_features: int = 4
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    # Set the same decoder weights
    decoder_weights = torch.rand_like(autoencoder.decoder.weight)
    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.data = decoder_weights.T
    autoencoder.decoder._weight.grad = torch.zeros_like(autoencoder.decoder.weight)  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.grad = torch.zeros_like(neel_autoencoder.W_dec)

    # Set the same tied bias weights
    neel_autoencoder.b_dec.data = autoencoder.tied_bias.data
    neel_autoencoder.b_enc.data = autoencoder.encoder.bias.data
    neel_autoencoder.W_enc.data = autoencoder.encoder.weight.data.T

    # Do a forward & backward pass so we have gradients
    test_batch = torch.randn(4, n_input_features)
    _learned, decoded = autoencoder.forward(test_batch)
    decoded.sum().backward()
    neel_decoded = neel_autoencoder.forward(test_batch)[1]
    neel_decoded.sum().backward()

    # Apply the unit norm
    autoencoder.decoder.constrain_weights_unit_norm()
    neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

    # Check the weight gradients are the same
    assert autoencoder.decoder.weight.grad is not None
    assert neel_autoencoder.W_dec.grad is not None
    assert torch.allclose(
        autoencoder.decoder.weight.grad, neel_autoencoder.W_dec.grad.T, rtol=1e-4
    )
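The unit-norm tests above hinge on `make_decoder_weights_and_grad_unit_norm` doing two things at once: rescaling every decoder row to unit norm, and projecting out the component of that row's gradient that points along the normalized row, so a subsequent optimizer step moves the dictionary direction around the unit sphere rather than off it. Below is a minimal standalone sketch of that projection, separate from the committed tests and using made-up shapes:

# Sketch only (not part of the commit): verify the gradient projection used above.
import torch

d_hidden, act_size = 4, 2
w_dec = torch.nn.Parameter(torch.rand(d_hidden, act_size))
w_dec.sum().backward()  # any scalar loss works; we just need some gradient

with torch.no_grad():
    w_normed = w_dec / w_dec.norm(dim=-1, keepdim=True)
    # Component of the gradient parallel to each (unit-normed) decoder row.
    parallel = (w_dec.grad * w_normed).sum(-1, keepdim=True) * w_normed
    w_dec.grad -= parallel
    w_dec.data = w_normed

# Rows are now unit norm and their gradients are orthogonal to them.
assert torch.allclose(w_dec.norm(dim=-1), torch.ones(d_hidden))
assert torch.allclose((w_dec.grad * w_dec).sum(-1), torch.zeros(d_hidden), atol=1e-6)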
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
"""Tests against Neel's Autoencoder Loss.

Compare module output against Neel's implementation at
https://github.com/neelnanda-io/1L-Sparse-Autoencoder/blob/main/utils.py .
"""
from typing import TypedDict

import pytest
import torch

from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
from sparse_autoencoder.loss.reducer import LossReducer
from sparse_autoencoder.tensor_types import (
    InputOutputActivationBatch,
    ItemTensor,
    LearnedActivationBatch,
)


def neel_loss(
    source_activations: InputOutputActivationBatch,
    learned_activations: LearnedActivationBatch,
    decoded_activations: InputOutputActivationBatch,
    l1_coefficient: float,
) -> tuple[ItemTensor, ItemTensor, ItemTensor]:
    """Neel's loss function."""
    l2_loss = (decoded_activations.float() - source_activations.float()).pow(2).sum(-1).mean(0)
    l1_loss = l1_coefficient * (learned_activations.float().abs().sum())
    loss = l2_loss + l1_loss
    return l1_loss, l2_loss, loss


def lib_loss(
    source_activations: InputOutputActivationBatch,
    learned_activations: LearnedActivationBatch,
    decoded_activations: InputOutputActivationBatch,
    l1_coefficient: float,
) -> tuple[ItemTensor, ItemTensor, ItemTensor]:
    """This library's loss function."""
    l1_loss_fn = LearnedActivationsL1Loss(
        l1_coefficient=float(l1_coefficient),
    )
    l2_loss_fn = L2ReconstructionLoss()

    loss_fn = LossReducer(l1_loss_fn, l2_loss_fn)

    l1_loss = l1_loss_fn.forward(source_activations, learned_activations, decoded_activations)
    l2_loss = l2_loss_fn.forward(source_activations, learned_activations, decoded_activations)
    total_loss = loss_fn.forward(source_activations, learned_activations, decoded_activations)
    return l1_loss.sum(), l2_loss.sum(), total_loss.sum()


class MockActivations(TypedDict):
    """Mock activations."""

    source_activations: InputOutputActivationBatch
    learned_activations: LearnedActivationBatch
    decoded_activations: InputOutputActivationBatch

@pytest.fixture()
def mock_activations() -> MockActivations:
    """Create mock activations.

    Returns:
        Dict of source activations, learned activations, and decoded activations.
    """
    source_activations = torch.rand(10, 20)
    learned_activations = torch.rand(10, 50)
    decoded_activations = torch.rand(10, 20)
    return {
        "source_activations": source_activations,
        "learned_activations": learned_activations,
        "decoded_activations": decoded_activations,
    }

def test_l1_loss_the_same(mock_activations: MockActivations) -> None:
    """Test that the L1 loss is the same."""
    l1_coefficient: float = 0.01

    neel_l1_loss = neel_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[0]

    lib_l1_loss = lib_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[0].sum()

    assert torch.allclose(neel_l1_loss, lib_l1_loss)


def test_l2_loss_the_same(mock_activations: MockActivations) -> None:
    """Test that the L2 loss is the same."""
    l1_coefficient: float = 0.01

    neel_l2_loss = neel_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[1]

    lib_l2_loss = lib_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[1].sum()

    # Fix for the fact that Neel's L2 loss is summed across the features dimension and then
    # averaged across the batch. By contrast for l1 it is summed across both features and batch
    # dimensions.
    neel_l2_loss_fixed = neel_l2_loss * len(mock_activations["source_activations"])

    assert torch.allclose(neel_l2_loss_fixed, lib_l2_loss)


@pytest.mark.skip("We believe Neel's L2 approach is different to the original paper.")
def test_total_loss_the_same(mock_activations: MockActivations) -> None:
    """Test that the total loss is the same."""
    l1_coefficient: float = 0.01

    neel_total_loss = neel_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[2].sum()

    lib_total_loss = lib_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[2].sum()

    assert torch.allclose(neel_total_loss, lib_total_loss)
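The batch-size multiplication in test_l2_loss_the_same above reconciles two reductions of the same per-element squared error, not two different error definitions: Neel's reduction sums over the feature dimension and then averages over the batch, so summing over both dimensions instead is larger by exactly the batch size. A standalone sketch with made-up shapes (not part of the committed tests):

# Sketch only (not part of the commit): the two L2 reductions differ by a factor of batch_size.
import torch

batch_size, n_features = 10, 20
decoded = torch.rand(batch_size, n_features)
source = torch.rand(batch_size, n_features)

squared_error = (decoded - source).pow(2)
neel_l2 = squared_error.sum(-1).mean(0)  # sum over features, then mean over the batch
summed_l2 = squared_error.sum()          # sum over both features and batch

assert torch.allclose(neel_l2 * batch_size, summed_l2)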
