Skip to content

Commit be08405

Browse files
authored
Fix the neuron resampler approach (#141)
This was normalising across an incorrect dimension and not firing at the correct point.
1 parent 15378ab commit be08405

30 files changed

+1083
-738
lines changed

docs/content/demo.ipynb

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,18 @@
3838
},
3939
{
4040
"cell_type": "code",
41-
"execution_count": 1,
41+
"execution_count": 5,
4242
"metadata": {},
43-
"outputs": [],
43+
"outputs": [
44+
{
45+
"name": "stdout",
46+
"output_type": "stream",
47+
"text": [
48+
"The autoreload extension is already loaded. To reload it, use:\n",
49+
" %reload_ext autoreload\n"
50+
]
51+
}
52+
],
4453
"source": [
4554
"# Check if we're in Colab\n",
4655
"try:\n",
@@ -62,22 +71,23 @@
6271
},
6372
{
6473
"cell_type": "code",
65-
"execution_count": 2,
74+
"execution_count": 6,
6675
"metadata": {},
6776
"outputs": [],
6877
"source": [
6978
"import os\n",
7079
"\n",
7180
"from sparse_autoencoder import (\n",
72-
" sweep,\n",
73-
" SweepConfig,\n",
81+
" ActivationResamplerHyperparameters,\n",
7482
" Hyperparameters,\n",
75-
" SourceModelHyperparameters,\n",
76-
" Parameter,\n",
77-
" SourceDataHyperparameters,\n",
78-
" Method,\n",
7983
" LossHyperparameters,\n",
84+
" Method,\n",
8085
" OptimizerHyperparameters,\n",
86+
" Parameter,\n",
87+
" SourceDataHyperparameters,\n",
88+
" SourceModelHyperparameters,\n",
89+
" sweep,\n",
90+
" SweepConfig,\n",
8191
")\n",
8292
"import wandb\n",
8393
"\n",
@@ -103,7 +113,7 @@
103113
},
104114
{
105115
"cell_type": "code",
106-
"execution_count": 3,
116+
"execution_count": 7,
107117
"metadata": {},
108118
"outputs": [
109119
{
@@ -112,28 +122,31 @@
112122
"SweepConfig(parameters=Hyperparameters(\n",
113123
" source_data=SourceDataHyperparameters(dataset_path=Parameter(value=NeelNanda/c4-code-tokenized-2b), context_size=Parameter(value=128))\n",
114124
" source_model=SourceModelHyperparameters(name=Parameter(value=gelu-2l), hook_site=Parameter(value=mlp_out), hook_layer=Parameter(value=0), hook_dimension=Parameter(value=512), dtype=Parameter(value=float32))\n",
115-
" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_resamples=Parameter(value=4), n_steps_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=819200), dead_neuron_threshold=Parameter(value=0.0))\n",
116-
" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=4))\n",
117-
" loss=LossHyperparameters(l1_coefficient=Parameter(values=[0.001, 0.0001, 1e-05]))\n",
118-
" optimizer=OptimizerHyperparameters(lr=Parameter(values=[0.001, 0.0001, 1e-05]), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
119-
" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=12), train_batch_size=Parameter(value=4096), max_store_size=Parameter(value=3145728), max_activations=Parameter(value=2000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=314572800), validation_number_activations=Parameter(value=1024))\n",
125+
" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=197885952), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=98942976), resample_dataset_size=Parameter(value=819200), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n",
126+
" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n",
127+
" loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n",
128+
" optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
129+
" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=8192), max_store_size=Parameter(value=2998272), max_activations=Parameter(value=1999847424), checkpoint_frequency=Parameter(value=47972352), validation_frequency=Parameter(value=99999744), validation_number_activations=Parameter(value=8192))\n",
120130
" random_seed=Parameter(value=49)\n",
121-
"), method=<Method.RANDOM: 'random'>, metric=Metric(name=total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None, runcap=None)"
131+
"), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)"
122132
]
123133
},
124-
"execution_count": 3,
134+
"execution_count": 7,
125135
"metadata": {},
126136
"output_type": "execute_result"
127137
}
128138
],
129139
"source": [
130140
"sweep_config = SweepConfig(\n",
131141
" parameters=Hyperparameters(\n",
142+
" activation_resampler=ActivationResamplerHyperparameters(\n",
143+
" threshold_is_dead_portion_fires=Parameter(1e-6),\n",
144+
" ),\n",
132145
" loss=LossHyperparameters(\n",
133-
" l1_coefficient=Parameter(values=[1e-3, 1e-4, 1e-5]),\n",
146+
" l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
134147
" ),\n",
135148
" optimizer=OptimizerHyperparameters(\n",
136-
" lr=Parameter(values=[1e-3, 1e-4, 1e-5]),\n",
149+
" lr=Parameter(max=1e-3, min=1e-5),\n",
137150
" ),\n",
138151
" source_model=SourceModelHyperparameters(\n",
139152
" name=Parameter(\"gelu-2l\"),\n",

docs/content/flexible_demo.ipynb

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,10 @@
379379
"outputs": [],
380380
"source": [
381381
"activation_resampler = ActivationResampler(\n",
382-
" resample_interval=10_000, n_steps_collate=10_000, max_resamples=5\n",
382+
" resample_interval=10_000,\n",
383+
" n_activations_activity_collate=10_000,\n",
384+
" max_n_resamples=5,\n",
385+
" n_learned_features=autoencoder.n_learned_features,\n",
383386
")"
384387
]
385388
},
@@ -400,9 +403,24 @@
400403
},
401404
{
402405
"cell_type": "code",
403-
"execution_count": null,
406+
"execution_count": 9,
404407
"metadata": {},
405-
"outputs": [],
408+
"outputs": [
409+
{
410+
"data": {
411+
"application/vnd.jupyter.widget-view+json": {
412+
"model_id": "2fe4955deca9463dbed606c9452d518e",
413+
"version_major": 2,
414+
"version_minor": 0
415+
},
416+
"text/plain": [
417+
"Resolving data files: 0%| | 0/28 [00:00<?, ?it/s]"
418+
]
419+
},
420+
"metadata": {},
421+
"output_type": "display_data"
422+
}
423+
],
406424
"source": [
407425
"source_data = PreTokenizedDataset(\n",
408426
" dataset_path=\"NeelNanda/c4-code-tokenized-2b\", context_size=int(hyperparameters[\"context_size\"])\n",
@@ -429,7 +447,7 @@
429447
},
430448
{
431449
"cell_type": "code",
432-
"execution_count": null,
450+
"execution_count": 10,
433451
"metadata": {},
434452
"outputs": [],
435453
"source": [

0 commit comments

Comments
 (0)