|
38 | 38 | },
|
39 | 39 | {
|
40 | 40 | "cell_type": "code",
|
41 |
| - "execution_count": 1, |
| 41 | + "execution_count": 5, |
42 | 42 | "metadata": {},
|
43 |
| - "outputs": [], |
| 43 | + "outputs": [ |
| 44 | + { |
| 45 | + "name": "stdout", |
| 46 | + "output_type": "stream", |
| 47 | + "text": [ |
| 48 | + "The autoreload extension is already loaded. To reload it, use:\n", |
| 49 | + " %reload_ext autoreload\n" |
| 50 | + ] |
| 51 | + } |
| 52 | + ], |
44 | 53 | "source": [
|
45 | 54 | "# Check if we're in Colab\n",
|
46 | 55 | "try:\n",
|
|
62 | 71 | },
|
63 | 72 | {
|
64 | 73 | "cell_type": "code",
|
65 |
| - "execution_count": 2, |
| 74 | + "execution_count": 6, |
66 | 75 | "metadata": {},
|
67 | 76 | "outputs": [],
|
68 | 77 | "source": [
|
69 | 78 | "import os\n",
|
70 | 79 | "\n",
|
71 | 80 | "from sparse_autoencoder import (\n",
|
72 |
| - " sweep,\n", |
73 |
| - " SweepConfig,\n", |
| 81 | + " ActivationResamplerHyperparameters,\n", |
74 | 82 | " Hyperparameters,\n",
|
75 |
| - " SourceModelHyperparameters,\n", |
76 |
| - " Parameter,\n", |
77 |
| - " SourceDataHyperparameters,\n", |
78 |
| - " Method,\n", |
79 | 83 | " LossHyperparameters,\n",
|
| 84 | + " Method,\n", |
80 | 85 | " OptimizerHyperparameters,\n",
|
| 86 | + " Parameter,\n", |
| 87 | + " SourceDataHyperparameters,\n", |
| 88 | + " SourceModelHyperparameters,\n", |
| 89 | + " sweep,\n", |
| 90 | + " SweepConfig,\n", |
81 | 91 | ")\n",
|
82 | 92 | "import wandb\n",
|
83 | 93 | "\n",
|
|
103 | 113 | },
|
104 | 114 | {
|
105 | 115 | "cell_type": "code",
|
106 |
| - "execution_count": 3, |
| 116 | + "execution_count": 7, |
107 | 117 | "metadata": {},
|
108 | 118 | "outputs": [
|
109 | 119 | {
|
|
112 | 122 | "SweepConfig(parameters=Hyperparameters(\n",
|
113 | 123 | " source_data=SourceDataHyperparameters(dataset_path=Parameter(value=NeelNanda/c4-code-tokenized-2b), context_size=Parameter(value=128))\n",
|
114 | 124 | " source_model=SourceModelHyperparameters(name=Parameter(value=gelu-2l), hook_site=Parameter(value=mlp_out), hook_layer=Parameter(value=0), hook_dimension=Parameter(value=512), dtype=Parameter(value=float32))\n",
|
115 |
| - " activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_resamples=Parameter(value=4), n_steps_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=819200), dead_neuron_threshold=Parameter(value=0.0))\n", |
116 |
| - " autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=4))\n", |
117 |
| - " loss=LossHyperparameters(l1_coefficient=Parameter(values=[0.001, 0.0001, 1e-05]))\n", |
118 |
| - " optimizer=OptimizerHyperparameters(lr=Parameter(values=[0.001, 0.0001, 1e-05]), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n", |
119 |
| - " pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=12), train_batch_size=Parameter(value=4096), max_store_size=Parameter(value=3145728), max_activations=Parameter(value=2000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=314572800), validation_number_activations=Parameter(value=1024))\n", |
| 125 | + " activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=197885952), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=98942976), resample_dataset_size=Parameter(value=819200), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n", |
| 126 | + " autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n", |
| 127 | + " loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n", |
| 128 | + " optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n", |
| 129 | + " pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=8192), max_store_size=Parameter(value=2998272), max_activations=Parameter(value=1999847424), checkpoint_frequency=Parameter(value=47972352), validation_frequency=Parameter(value=99999744), validation_number_activations=Parameter(value=8192))\n", |
120 | 130 | " random_seed=Parameter(value=49)\n",
|
121 |
| - "), method=<Method.RANDOM: 'random'>, metric=Metric(name=total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None, runcap=None)" |
| 131 | + "), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)" |
122 | 132 | ]
|
123 | 133 | },
|
124 |
| - "execution_count": 3, |
| 134 | + "execution_count": 7, |
125 | 135 | "metadata": {},
|
126 | 136 | "output_type": "execute_result"
|
127 | 137 | }
|
128 | 138 | ],
|
129 | 139 | "source": [
|
130 | 140 | "sweep_config = SweepConfig(\n",
|
131 | 141 | " parameters=Hyperparameters(\n",
|
| 142 | + " activation_resampler=ActivationResamplerHyperparameters(\n", |
| 143 | + " threshold_is_dead_portion_fires=Parameter(1e-6),\n", |
| 144 | + " ),\n", |
132 | 145 | " loss=LossHyperparameters(\n",
|
133 |
| - " l1_coefficient=Parameter(values=[1e-3, 1e-4, 1e-5]),\n", |
| 146 | + " l1_coefficient=Parameter(max=1e-2, min=4e-3),\n", |
134 | 147 | " ),\n",
|
135 | 148 | " optimizer=OptimizerHyperparameters(\n",
|
136 |
| - " lr=Parameter(values=[1e-3, 1e-4, 1e-5]),\n", |
| 149 | + " lr=Parameter(max=1e-3, min=1e-5),\n", |
137 | 150 | " ),\n",
|
138 | 151 | " source_model=SourceModelHyperparameters(\n",
|
139 | 152 | " name=Parameter(\"gelu-2l\"),\n",
|
|
0 commit comments