
Commit 29f01d0

Add save and load methods to the model (#172)
1 parent 1371f83 commit 29f01d0

23 files changed: +475, -253 lines
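The model file itself (sparse_autoencoder/autoencoder/model.py) is not excerpted on this page, so the exact save/load signatures are not visible here. A rough usage sketch only, with the method names taken from the commit title and everything else assumed:

# Rough usage sketch, not the confirmed API: commit #172 adds save and load
# methods to the model, but their signatures are not shown in this excerpt.
# `save` taking a file path and `load` being a classmethod are assumptions.
from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

config = SparseAutoencoderConfig(n_input_features=1024, n_learned_features=4096)
autoencoder = SparseAutoencoder(config)

autoencoder.save("sae.pt")  # assumed signature
restored = SparseAutoencoder.load("sae.pt")  # assumed classmethod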

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ docs/content/reference
 
 # Wandb
 wandb/
+artifacts/
 
 # Scratch files
 scratch.py

.vscode/cspell.json

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
     "autocast",
     "autoencoder",
     "autoencoders",
+    "autoencoding",
     "autofix",
     "capturable",
     "categoricalwprobabilities",
@@ -76,6 +77,7 @@
     "optim",
     "penality",
     "perp",
+    "pickleable",
     "polysemantic",
     "polysemantically",
     "polysemanticity",

README.md

Lines changed: 3 additions & 3 deletions
@@ -43,9 +43,9 @@ The library is designed to be modular. By default it takes the approach from [To
 Monosemanticity: Decomposing Language Models With Dictionary Learning
 ](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
 the library and get started quickly. Then when you need to customise something, you can just extend
-the abstract class for that component (e.g. you can extend `AbstractEncoder` if you want to
-customise the encoder layer, and then easily drop it in the standard `SparseAutoencoder` model to
-keep everything else as is. Every component is fully documented, so it's nice and easy to do this.
+the class for that component (e.g. you can extend `SparseAutoencoder` if you want to customise the
+model, and then drop it back into the training pipeline. Every component is fully documented, so
+it's nice and easy to do this.
 
 ## Demo
 
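A minimal sketch of the workflow the new README wording describes, assuming only that `SparseAutoencoder` follows standard `torch.nn.Module` conventions; the subclass name is hypothetical:

# Hypothetical subclass illustrating the README's customisation workflow.
from sparse_autoencoder import SparseAutoencoder

class MySparseAutoencoder(SparseAutoencoder):
    """Custom variant that can be dropped back into the training pipeline."""

    def forward(self, x):
        # Customise whatever behaviour you need, delegating the rest.
        return super().forward(x)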
docs/content/flexible_demo.ipynb

Lines changed: 5 additions & 2 deletions
@@ -103,6 +103,7 @@
     "    Pipeline,\n",
     "    PreTokenizedDataset,\n",
     "    SparseAutoencoder,\n",
+    "    SparseAutoencoderConfig,\n",
     ")\n",
     "import wandb\n",
     "\n",
@@ -235,8 +236,10 @@
    "source": [
     "expansion_factor = hyperparameters[\"expansion_factor\"]\n",
     "autoencoder = SparseAutoencoder(\n",
-    "    n_input_features=autoencoder_input_dim,  # size of the activations we are autoencoding\n",
-    "    n_learned_features=int(autoencoder_input_dim * expansion_factor),  # size of SAE\n",
+    "    SparseAutoencoderConfig(\n",
+    "        n_input_features=autoencoder_input_dim,  # size of the activations we are autoencoding\n",
+    "        n_learned_features=int(autoencoder_input_dim * expansion_factor),  # size of SAE\n",
+    "    )\n",
     ").to(device)\n",
     "autoencoder"
     ]
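De-escaped from the notebook JSON, the updated cell constructs the model from a single config object rather than bare keyword arguments (`hyperparameters`, `autoencoder_input_dim` and `device` are defined earlier in the demo):

# The updated demo cell, de-escaped from JSON for readability.
expansion_factor = hyperparameters["expansion_factor"]
autoencoder = SparseAutoencoder(
    SparseAutoencoderConfig(
        n_input_features=autoencoder_input_dim,  # size of the activations we are autoencoding
        n_learned_features=int(autoencoder_input_dim * expansion_factor),  # size of SAE
    )
).to(device)
autoencoder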

sparse_autoencoder/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,7 @@
 """Sparse Autoencoder Library."""
 from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
 from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
-from sparse_autoencoder.autoencoder.model import SparseAutoencoder
+from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
 from sparse_autoencoder.loss.abstract_loss import LossReductionType
 from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
 from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
@@ -77,6 +77,7 @@
     "SourceModelHyperparameters",
     "SourceModelRuntimeHyperparameters",
     "SparseAutoencoder",
+    "SparseAutoencoderConfig",
     "sweep",
     "SweepConfig",
     "TensorActivationStore",

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
-from sparse_autoencoder.train.utils import get_model_device
+from sparse_autoencoder.train.utils.get_model_device import get_model_device
 
 
 class LossInputActivationsTuple(NamedTuple):

sparse_autoencoder/activation_resampler/tests/test_activation_resampler.py

Lines changed: 12 additions & 10 deletions
@@ -9,7 +9,7 @@
 from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
 from sparse_autoencoder.activation_store.base_store import ActivationStore
 from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
-from sparse_autoencoder.autoencoder.model import SparseAutoencoder
+from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
 from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
 from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
 from sparse_autoencoder.loss.reducer import LossReducer
@@ -43,9 +43,11 @@ def full_activation_store() -> ActivationStore:
 def autoencoder_model() -> SparseAutoencoder:
     """Create a dummy autoencoder model."""
     return SparseAutoencoder(
-        n_components=DEFAULT_N_COMPONENTS,
-        n_input_features=DEFAULT_N_INPUT_FEATURES,
-        n_learned_features=DEFAULT_N_LEARNED_FEATURES,
+        SparseAutoencoderConfig(
+            n_input_features=DEFAULT_N_INPUT_FEATURES,
+            n_learned_features=DEFAULT_N_LEARNED_FEATURES,
+            n_components=DEFAULT_N_COMPONENTS,
+        )
     )
 
 
@@ -126,7 +128,7 @@ def test_more_items_than_in_store_error(
 ):
     ActivationResampler(
         resample_dataset_size=DEFAULT_N_ACTIVATIONS_STORE + 1,
-        n_learned_features=autoencoder_model.n_learned_features,
+        n_learned_features=DEFAULT_N_LEARNED_FEATURES,
     ).compute_loss_and_get_activations(
         store=full_activation_store,
         autoencoder=autoencoder_model,
@@ -285,7 +287,7 @@ def test_no_changes_if_no_dead_neurons(
         resample_interval=10,
         n_components=DEFAULT_N_COMPONENTS,
         n_activations_activity_collate=10,
-        n_learned_features=autoencoder_model.n_learned_features,
+        n_learned_features=DEFAULT_N_LEARNED_FEATURES,
         resample_dataset_size=100,
     )
     updates = resampler.step_resampler(
@@ -328,7 +330,7 @@ def test_updates_dead_neuron_parameters(
         resample_interval=10,
         n_activations_activity_collate=10,
         n_components=DEFAULT_N_COMPONENTS,
-        n_learned_features=autoencoder_model.n_learned_features,
+        n_learned_features=DEFAULT_N_LEARNED_FEATURES,
         resample_dataset_size=100,
     )
     parameter_updates = resampler.step_resampler(
@@ -343,7 +345,7 @@ def test_updates_dead_neuron_parameters(
     # Check the updated ones have changed
     for component_idx, neuron_idx in dead_neurons:
         # Decoder
-        decoder_weights = current_parameters["decoder._weight"]
+        decoder_weights = current_parameters["decoder.weight"]
         current_dead_neuron_weights = decoder_weights[component_idx, neuron_idx]
         updated_dead_decoder_weights = parameter_updates[
             component_idx
@@ -353,7 +355,7 @@ def test_updates_dead_neuron_parameters(
         ), "Dead decoder weights should have changed."
 
         # Encoder
-        current_dead_encoder_weights = current_parameters["encoder._weight"][
+        current_dead_encoder_weights = current_parameters["encoder.weight"][
             component_idx, neuron_idx
         ]
         updated_dead_encoder_weights = parameter_updates[
@@ -363,7 +365,7 @@
             current_dead_encoder_weights, updated_dead_encoder_weights
         ), "Dead encoder weights should have changed."
 
-        current_dead_encoder_bias = current_parameters["encoder._bias"][
+        current_dead_encoder_bias = current_parameters["encoder.bias"][
            component_idx, neuron_idx
         ]
         updated_dead_encoder_bias = parameter_updates[component_idx].dead_encoder_bias_updates

sparse_autoencoder/autoencoder/abstract_autoencoder.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

sparse_autoencoder/autoencoder/components/linear_encoder.py

Lines changed: 12 additions & 27 deletions
@@ -42,33 +42,18 @@ class LinearEncoder(Module):
 
     _n_components: int | None
 
-    _weight: Float[
+    weight: Float[
         Parameter,
         Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
     ]
-    """Weight parameter internal state."""
+    """Weight parameter.
 
-    _bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
-    """Bias parameter internal state."""
-
-    @property
-    def weight(
-        self,
-    ) -> Float[
-        Parameter,
-        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
-    ]:
-        """Weight parameter.
-
-        Each row in the weights matrix acts as a dictionary vector, representing a single basis
-        element in the learned activation space.
-        """
-        return self._weight
+    Each row in the weights matrix acts as a dictionary vector, representing a single basis
+    element in the learned activation space.
+    """
 
-    @property
-    def bias(self) -> Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
-        """Bias parameter."""
-        return self._bias
+    bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
+    """Bias parameter."""
 
     @property
     def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
@@ -109,12 +94,12 @@ def __init__(
         self._input_features = input_features
         self._n_components = n_components
 
-        self._weight = Parameter(
+        self.weight = Parameter(
             torch.empty(
                 shape_with_optional_dimensions(n_components, learnt_features, input_features),
             )
         )
-        self._bias = Parameter(
+        self.bias = Parameter(
             torch.zeros(shape_with_optional_dimensions(n_components, learnt_features))
         )
         self.activation_function = ReLU()
@@ -125,12 +110,12 @@ def reset_parameters(self) -> None:
         """Initialize or reset the parameters."""
         # Assumes we are using ReLU activation function (for e.g. leaky ReLU, the `a` parameter and
         # `nonlinerity` must be changed.
-        init.kaiming_uniform_(self._weight, nonlinearity="relu")
+        init.kaiming_uniform_(self.weight, nonlinearity="relu")
 
         # Bias (approach from nn.Linear)
-        fan_in = self._weight.size(1)
+        fan_in = self.weight.size(1)
         bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-        init.uniform_(self._bias, -bound, bound)
+        init.uniform_(self.bias, -bound, bound)
 
     def forward(
         self,
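Replacing the `_weight`/`_bias` attributes (plus read-only properties) with plain `Parameter` attributes also renames the keys that `state_dict()` produces, which is what the test changes above reflect (`encoder._weight` becoming `encoder.weight`) and presumably what the new save method writes to disk. A standalone sketch of the underlying torch mechanics, not this library's classes:

import torch
from torch import nn


class WithProperty(nn.Module):
    """Parameter hidden behind a property: the state_dict key is the private name."""

    def __init__(self) -> None:
        super().__init__()
        self._weight = nn.Parameter(torch.zeros(2, 3))

    @property
    def weight(self) -> nn.Parameter:
        return self._weight


class PlainAttribute(nn.Module):
    """Parameter as a public attribute: the state_dict key matches the public name."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 3))


print(list(WithProperty().state_dict()))    # ['_weight']
print(list(PlainAttribute().state_dict()))  # ['weight']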

sparse_autoencoder/autoencoder/components/tests/test_compare_neel_implementation.py

Lines changed: 20 additions & 12 deletions
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 
-from sparse_autoencoder.autoencoder.model import SparseAutoencoder
+from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
 
 
 class NeelAutoencoder(nn.Module):
@@ -66,8 +66,10 @@ def test_biases_initialised_same_way() -> None:
 
     torch.random.manual_seed(0)
     autoencoder = SparseAutoencoder(
-        n_input_features=n_input_features,
-        n_learned_features=n_learned_features,
+        SparseAutoencoderConfig(
+            n_input_features=n_input_features,
+            n_learned_features=n_learned_features,
+        )
     )
 
     torch.random.manual_seed(0)
@@ -91,8 +93,10 @@ def test_forward_pass_same_weights() -> None:
     l1_coefficient: float = 0.01
 
     autoencoder = SparseAutoencoder(
-        n_input_features=n_input_features,
-        n_learned_features=n_learned_features,
+        SparseAutoencoderConfig(
+            n_input_features=n_input_features,
+            n_learned_features=n_learned_features,
+        )
     )
     neel_autoencoder = NeelAutoencoder(
         d_hidden=n_learned_features,
@@ -122,8 +126,10 @@ def test_unit_norm_weights() -> None:
     l1_coefficient: float = 0.01
 
     autoencoder = SparseAutoencoder(
-        n_input_features=n_input_features,
-        n_learned_features=n_learned_features,
+        SparseAutoencoderConfig(
+            n_input_features=n_input_features,
+            n_learned_features=n_learned_features,
+        )
     )
     neel_autoencoder = NeelAutoencoder(
         d_hidden=n_learned_features,
@@ -135,7 +141,7 @@ def test_unit_norm_weights() -> None:
 
     # Set the same decoder weights
     decoder_weights = torch.rand_like(autoencoder.decoder.weight)
-    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
+    autoencoder.decoder.weight.data = decoder_weights  # type: ignore
     neel_autoencoder.W_dec.data = decoder_weights.T
 
     # Do a forward & backward pass so we have gradients
@@ -165,8 +171,10 @@ def test_unit_norm_weights_grad() -> None:
     l1_coefficient: float = 0.01
 
     autoencoder = SparseAutoencoder(
-        n_input_features=n_input_features,
-        n_learned_features=n_learned_features,
+        SparseAutoencoderConfig(
+            n_input_features=n_input_features,
+            n_learned_features=n_learned_features,
+        )
     )
     neel_autoencoder = NeelAutoencoder(
         d_hidden=n_learned_features,
@@ -176,9 +184,9 @@ def test_unit_norm_weights_grad() -> None:
 
     # Set the same decoder weights
     decoder_weights = torch.rand_like(autoencoder.decoder.weight)
-    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
+    autoencoder.decoder.weight.data = decoder_weights  # type: ignore
     neel_autoencoder.W_dec.data = decoder_weights.T
-    autoencoder.decoder._weight.grad = torch.zeros_like(autoencoder.decoder.weight)  # noqa: SLF001 # type: ignore
+    autoencoder.decoder.weight.grad = torch.zeros_like(autoencoder.decoder.weight)  # type: ignore
     neel_autoencoder.W_dec.grad = torch.zeros_like(neel_autoencoder.W_dec)
 
     # Set the same tied bias weights
