"""Compare the SAE implementation to Neel's 1L Implementation.

https://github.com/neelnanda-io/1L-Sparse-Autoencoder/blob/main/utils.py
"""
import torch
from torch import nn

from sparse_autoencoder.autoencoder.model import SparseAutoencoder


class NeelAutoencoder(nn.Module):
    """Neel's 1L autoencoder implementation."""

    def __init__(
        self,
        d_hidden: int,
        act_size: int,
        l1_coeff: float,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the autoencoder."""
        super().__init__()
        self.b_dec = nn.Parameter(torch.zeros(act_size, dtype=dtype))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(act_size, d_hidden, dtype=dtype))
        )
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, act_size, dtype=dtype))
        )

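        # Normalise each decoder row (one learned feature's output direction) to unit norm at
        # initialisation, matching the unit-norm constraint applied during training.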
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass."""
        x_cent = x - self.b_dec
        acts = nn.functional.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
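        # L2 loss: squared reconstruction error summed over features, then averaged over the batch.
        # L1 loss: absolute activations summed over both batch and features (not averaged).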
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    def make_decoder_weights_and_grad_unit_norm(self) -> None:
        """Normalise decoder weights to unit norm and remove the parallel gradient component."""
        with torch.no_grad():
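            # Remove the component of the gradient that is parallel to each (normalised) decoder
            # row, so a gradient step does not change the row norms to first order.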
            weight_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
            weight_dec_grad_proj = (self.W_dec.grad * weight_dec_normed).sum(
                -1, keepdim=True
            ) * weight_dec_normed
            self.W_dec.grad -= weight_dec_grad_proj
            # Bugfix(?)
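            # (This also resets the weights themselves to exactly unit norm, not only the
            # gradient, which appears to be what the "Bugfix(?)" refers to.)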
            self.W_dec.data = weight_dec_normed


def test_biases_initialised_same_way() -> None:
    """Test that the biases are initialised the same."""
    n_input_features: int = 2
    n_learned_features: int = 3
    l1_coefficient: float = 0.01

    torch.random.manual_seed(0)
    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )

    torch.random.manual_seed(0)
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    assert torch.allclose(autoencoder.tied_bias, neel_autoencoder.b_dec)
    # Note that we can't compare the weights, as Neel's implementation uses transposed tensors
    # and applies Kaiming initialisation incorrectly (it uses the leaky-ReLU variant with the
    # wrong fan strategy for the orientation used). Note also that the encoder bias is
    # initialised to zero in Neel's implementation, whereas we use the standard PyTorch
    # initialisation.
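    # (For example, `W_enc` has shape `(act_size, d_hidden)` and is used as `x @ W_enc`, so its
    # effective fan-in is `act_size`, but `kaiming_uniform_` treats the second dimension,
    # `d_hidden`, as the fan-in.)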


def test_forward_pass_same_weights() -> None:
    """Test a forward pass with the same weights."""
    n_input_features: int = 12
    n_learned_features: int = 48
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    # Set the same weights
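    # (The SAE stores Linear-style weights with shape (out_features, in_features), whereas
    # Neel's W_enc / W_dec use the opposite orientation, hence the transposes.)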
    autoencoder.encoder.weight.data = neel_autoencoder.W_enc.data.T
    autoencoder.decoder.weight.data = neel_autoencoder.W_dec.data.T
    autoencoder.tied_bias.data = neel_autoencoder.b_dec.data
    autoencoder.encoder.bias.data = neel_autoencoder.b_enc.data

    # Create some test data
    test_batch = torch.randn(4, n_input_features)
    learned, decoded = autoencoder.forward(test_batch)
    _loss, x_reconstruct, acts, _l2_loss, _l1_loss = neel_autoencoder.forward(test_batch)

    assert torch.allclose(learned, acts)
    assert torch.allclose(decoded, x_reconstruct)


def test_unit_norm_weights() -> None:
    """Test that the decoder weights are unit normalized in the same way."""
    n_input_features: int = 2
    n_learned_features: int = 4
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )
    pre_unit_norm_weights = autoencoder.decoder.weight.clone()
    pre_unit_norm_neel_weights = neel_autoencoder.W_dec.clone()

    # Set the same decoder weights
    decoder_weights = torch.rand_like(autoencoder.decoder.weight)
    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.data = decoder_weights.T

    # Do a forward & backward pass so we have gradients
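    # (Neel's make_decoder_weights_and_grad_unit_norm reads W_dec.grad, so it must be populated.)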
    test_batch = torch.randn(4, n_input_features)
    _learned, decoded = autoencoder.forward(test_batch)
    decoded.sum().backward()
    neel_decoded = neel_autoencoder.forward(test_batch)[1]
    neel_decoded.sum().backward()

    # Apply the unit norm
    autoencoder.decoder.constrain_weights_unit_norm()
    neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

    # Check the decoder weights are the same with both models
    assert torch.allclose(autoencoder.decoder.weight, neel_autoencoder.W_dec.T)

    # Sanity check that the weights haven't simply stayed the same as before the unit norm
    assert not torch.allclose(autoencoder.decoder.weight, pre_unit_norm_weights)
    assert not torch.allclose(neel_autoencoder.W_dec, pre_unit_norm_neel_weights)


def test_unit_norm_weights_grad() -> None:
    """Test that the decoder weight gradients are adjusted in the same way by the unit norm."""
    torch.random.manual_seed(42)
    n_input_features: int = 2
    n_learned_features: int = 4
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    # Set the same decoder weights
    decoder_weights = torch.rand_like(autoencoder.decoder.weight)
    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.data = decoder_weights.T
    autoencoder.decoder._weight.grad = torch.zeros_like(autoencoder.decoder.weight)  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.grad = torch.zeros_like(neel_autoencoder.W_dec)

    # Set the same biases and encoder weights
    neel_autoencoder.b_dec.data = autoencoder.tied_bias.data
    neel_autoencoder.b_enc.data = autoencoder.encoder.bias.data
    neel_autoencoder.W_enc.data = autoencoder.encoder.weight.data.T

    # Do a forward & backward pass so we have gradients
    test_batch = torch.randn(4, n_input_features)
    _learned, decoded = autoencoder.forward(test_batch)
    decoded.sum().backward()
    neel_decoded = neel_autoencoder.forward(test_batch)[1]
    neel_decoded.sum().backward()

    # Apply the unit norm
    autoencoder.decoder.constrain_weights_unit_norm()
    neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

    # Check the weight gradients are the same
    assert autoencoder.decoder.weight.grad is not None
    assert neel_autoencoder.W_dec.grad is not None
    assert torch.allclose(autoencoder.decoder.weight.grad, neel_autoencoder.W_dec.grad.T, rtol=1e-4)