
Commit eed29de

Add pre-download sweep support (#154)
1 parent: 26ae485

4 files changed (+13 additions, -5 deletions)

4 files changed

+13
-5
lines changed

sparse_autoencoder/train/sweep.py

Lines changed: 5 additions & 3 deletions
@@ -155,10 +155,11 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
 
     if hyperparameters["source_data"]["pre_tokenized"]:
         return PreTokenizedDataset(
-            dataset_path=hyperparameters["source_data"]["dataset_path"],
             context_size=hyperparameters["source_data"]["context_size"],
             dataset_dir=dataset_dir,
             dataset_files=dataset_files,
+            dataset_path=hyperparameters["source_data"]["dataset_path"],
+            pre_download=hyperparameters["source_data"]["pre_download"],
         )
 
     if hyperparameters["source_data"]["tokenizer_name"] is None:

@@ -171,12 +172,13 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
     tokenizer = AutoTokenizer.from_pretrained(hyperparameters["source_data"]["tokenizer_name"])
 
     return TextDataset(
-        dataset_path=hyperparameters["source_data"]["dataset_path"],
         context_size=hyperparameters["source_data"]["context_size"],
-        tokenizer=tokenizer,
         dataset_dir=dataset_dir,
         dataset_files=dataset_files,
+        dataset_path=hyperparameters["source_data"]["dataset_path"],
         n_processes_preprocessing=4,
+        pre_download=hyperparameters["source_data"]["pre_download"],
+        tokenizer=tokenizer,
     )

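For reference, setup_source_data reads the new flag from the "source_data" section of the runtime hyperparameters and forwards it to whichever dataset class it constructs. Below is a minimal sketch of that section with pre-downloading enabled; the keys mirror SourceDataRuntimeHyperparameters from this commit, while the concrete values are purely illustrative.

# Sketch of the "source_data" runtime hyperparameters consumed by
# setup_source_data(). Keys follow SourceDataRuntimeHyperparameters;
# the concrete values are illustrative, not taken from this commit.
source_data = {
    "context_size": 128,
    "dataset_dir": None,
    "dataset_files": None,
    "dataset_path": "NeelNanda/c4-code-tokenized-2b",
    "pre_download": True,   # new flag; assumed to fetch the dataset up front
    "pre_tokenized": True,
    "tokenizer_name": None,
}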
sparse_autoencoder/train/sweep_config.py

Lines changed: 6 additions & 2 deletions
@@ -177,9 +177,12 @@ class SourceDataHyperparameters(NestedParameter):
     dataset_dir: Parameter[str] | None = field(default=None)
     """Dataset directory (within the HF dataset)"""
 
-    dataset_files: Parameter[str] | None = field(default=None)
+    dataset_files: Parameter[list[str]] | None = field(default=None)
     """Dataset files (within the HF dataset)."""
 
+    pre_download: Parameter[bool] = field(default=Parameter(value=False))
+    """Whether to pre-download the dataset."""
+
     pre_tokenized: Parameter[bool] = field(default=Parameter(value=True))
     """If the dataset is pre-tokenized."""
 

@@ -209,8 +212,9 @@ class SourceDataRuntimeHyperparameters(TypedDict):
 
     context_size: int
     dataset_dir: str | None
-    dataset_files: str | None
+    dataset_files: list[str] | None
     dataset_path: str
+    pre_download: bool
     pre_tokenized: bool
     tokenizer_name: str | None
 

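The schema changes above do two things: dataset_files is widened from a single string to a list of strings, and a pre_download field is added with a default of Parameter(value=False). A rough sketch of how the new fields might be set in a sweep config follows; the constructor call and all values are assumptions for illustration, not part of this commit.

# Hypothetical sweep-config snippet using the new fields. Field names come
# from SourceDataHyperparameters above; the dataset path, directory, and
# file names are made up for illustration.
source_data = SourceDataHyperparameters(
    dataset_path=Parameter(value="NeelNanda/c4-code-tokenized-2b"),
    dataset_dir=Parameter(value="data"),
    dataset_files=Parameter(value=["train-00000-of-00010.parquet"]),  # now list[str]
    pre_download=Parameter(value=True),   # overrides the Parameter(value=False) default
    pre_tokenized=Parameter(value=True),
)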
sparse_autoencoder/train/tests/test_sweep.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def dummy_hyperparameters() -> RuntimeHyperparameters:
             "dataset_path": "NeelNanda/c4-code-tokenized-2b",
             "pre_tokenized": True,
             "tokenizer_name": None,
+            "pre_download": False,
         },
         "source_model": {
             "dtype": "float32",

sparse_autoencoder/train/utils/wandb_sweep_types.py

Lines changed: 1 addition & 0 deletions
@@ -269,6 +269,7 @@ def __repr__(self) -> str:
     float,
     int,
     str,
+    list[str],
 )
 
 

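The wandb_sweep_types.py change extends the set of types accepted as Parameter values so that the list-valued dataset_files field above can be expressed in a sweep. A one-line illustration (the file name is hypothetical):

# Parameter values may now be lists of strings, e.g. for dataset_files.
dataset_files = Parameter(value=["train-00000-of-00010.parquet"])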