Commit b8cc0cc

Support multi-GPU training with DataParallel (#178)
Support multi-GPU, single-node training by default. Note that more advanced techniques (e.g. DistributedDataParallel or DeepSpeed) can yield even better performance.
1 parent e6268fb commit b8cc0cc
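In practical terms, a model is wrapped once in `DataParallelWithModelAttributes` (added below in `sparse_autoencoder/utils/data_parallel.py`); forward passes are then split across all visible GPUs by the underlying `torch.nn.DataParallel`, while attribute access still reaches the wrapped model. A minimal sketch of the pattern, with arbitrary placeholder sizes:

import torch

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes

autoencoder = SparseAutoencoder(
    SparseAutoencoderConfig(n_input_features=512, n_learned_features=2048)
)

# Wrap once: forward passes are sharded across GPUs (or run as-is on CPU),
# and attributes such as `config` are transparently delegated to the model.
autoencoder = DataParallelWithModelAttributes(autoencoder)

activations = torch.randn(64, 512)  # placeholder batch of source activations
output = autoencoder(activations)
print(autoencoder.config.n_learned_features)  # 2048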

File tree

8 files changed: +64 -15 lines changed


pyproject.toml

Lines changed: 4 additions & 0 deletions

@@ -125,6 +125,10 @@
 cmd="mkdocs serve"
 help="Hot reload the docs site (so changes appear instantly)"

+[tool.poe.tasks.run]
+cmd="poetry run python"
+help="Run a python file (append with file name)"
+
 [build-system]
 build-backend="poetry.core.masonry.api"
 requires=["poetry-core"]
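With this task defined, a script can presumably be launched as `poe run path/to/script.py`: poethepoet appends extra command-line arguments to `cmd` tasks, so this expands to `poetry run python path/to/script.py`.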

sparse_autoencoder/activation_resampler/abstract_activation_resampler.py

Lines changed: 2 additions & 1 deletion

@@ -10,6 +10,7 @@
 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 @dataclass
@@ -58,7 +59,7 @@ def step_resampler(
         self,
         batch_neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE],
         activation_store: TensorActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults] | None:

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 4 additions & 3 deletions

@@ -20,6 +20,7 @@
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 class LossInputActivationsTuple(NamedTuple):
@@ -188,7 +189,7 @@ def _get_dead_neuron_indices(
     def compute_loss_and_get_activations(
         self,
         store: ActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> LossInputActivationsTuple:
@@ -421,7 +422,7 @@ def renormalize_and_scale(
     def resample_dead_neurons(
         self,
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults]:
@@ -513,7 +514,7 @@ def step_resampler(
             Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
         ],
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults] | None:

sparse_autoencoder/source_data/tests/test_text_dataset.py

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
 """Pile Uncopyrighted Dataset Tests."""
 import pytest
-from transformers import PreTrainedTokenizerFast
+from transformers import GPT2Tokenizer

 from sparse_autoencoder.source_data.text_dataset import TextDataset

@@ -9,7 +9,7 @@
 @pytest.mark.parametrize("context_size", [50, 250])
 def test_tokenized_prompts_correct_size(context_size: int) -> None:
     """Test that the tokenized prompts have the correct context size."""
-    tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
+    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

     data = TextDataset(
         tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted"
@@ -31,7 +31,7 @@ def test_dataloader_correct_size_items() -> None:
     """Test the dataloader returns the correct number & sized items."""
     batch_size = 10
     context_size = 250
-    tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
+    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     data = TextDataset(
         tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted"
     )

sparse_autoencoder/source_model/replace_activations_hook.py

Lines changed: 2 additions & 1 deletion

@@ -5,6 +5,7 @@
 from transformer_lens.hook_points import HookPoint

 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 if TYPE_CHECKING:
@@ -15,7 +16,7 @@
 def replace_activations_hook(
     value: Tensor,
     hook: HookPoint,  # noqa: ARG001
-    sparse_autoencoder: SparseAutoencoder,
+    sparse_autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
     component_idx: int | None = None,
 ) -> Tensor:
     """Replace activations hook.

sparse_autoencoder/train/pipeline.py

Lines changed: 5 additions & 4 deletions

@@ -31,6 +31,7 @@
 from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class Pipeline:
     activation_resampler: AbstractActivationResampler | None
     """Activation resampler to use."""

-    autoencoder: SparseAutoencoder
+    autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder]
     """Sparse autoencoder to train."""

     cache_names: list[str]
@@ -79,7 +80,7 @@ class Pipeline:
     source_dataset: SourceDataset
     """Source dataset to generate activation data from (tokenized prompts)."""

-    source_model: HookedTransformer
+    source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer]
     """Source model to get activations from."""

     total_activations_trained_on: int = 0
@@ -95,13 +96,13 @@ def n_components(self) -> int:
     def __init__(
         self,
         activation_resampler: AbstractActivationResampler | None,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         cache_names: list[str],
         layer: NonNegativeInt,
         loss: AbstractLoss,
         optimizer: AbstractOptimizerWithReset,
         source_dataset: SourceDataset,
-        source_model: HookedTransformer,
+        source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer],
         run_name: str = "sparse_autoencoder",
         checkpoint_directory: Path = DEFAULT_CHECKPOINT_DIRECTORY,
         log_frequency: PositiveInt = 100,

sparse_autoencoder/train/sweep.py

Lines changed: 4 additions & 3 deletions

@@ -25,6 +25,7 @@
     RuntimeHyperparameters,
     SweepConfig,
 )
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 def setup_activation_resampler(hyperparameters: RuntimeHyperparameters) -> ActivationResampler:
@@ -239,8 +240,8 @@ def stop_layer_from_cache_names(cache_names: list[str]) -> int:

 def run_training_pipeline(
     hyperparameters: RuntimeHyperparameters,
-    source_model: HookedTransformer,
-    autoencoder: SparseAutoencoder,
+    source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer],
+    autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
     loss: LossReducer,
     optimizer: AdamWithReset,
     activation_resampler: ActivationResampler,
@@ -324,7 +325,7 @@ def train() -> None:
     run_training_pipeline(
         hyperparameters=hyperparameters,
         source_model=source_model,
-        autoencoder=autoencoder,
+        autoencoder=DataParallelWithModelAttributes(autoencoder),
         loss=loss_function,
         optimizer=optimizer,
         activation_resampler=activation_resampler,
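Note that the call site wraps only the autoencoder. The widened signature of `run_training_pipeline` also admits a wrapped source model, so a hypothetical call (not part of this diff) could parallelize both:

run_training_pipeline(
    hyperparameters=hyperparameters,
    source_model=DataParallelWithModelAttributes(source_model),
    autoencoder=DataParallelWithModelAttributes(autoencoder),
    loss=loss_function,
    optimizer=optimizer,
    activation_resampler=activation_resampler,
)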
sparse_autoencoder/utils/data_parallel.py

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+"""Data parallel utils."""
+from typing import Any, Generic, TypeVar
+
+from torch.nn import DataParallel, Module
+
+
+T = TypeVar("T", bound=Module)
+
+
+class DataParallelWithModelAttributes(DataParallel[T], Generic[T]):
+    """Data parallel with access to underlying model attributes/methods.
+
+    Allows access to underlying model attributes/methods, which is not possible with the default
+    `DataParallel` class. Based on:
+    https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
+
+    Example:
+        >>> from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
+        >>> model = SparseAutoencoder(SparseAutoencoderConfig(
+        ...     n_input_features=2,
+        ...     n_learned_features=4,
+        ... ))
+        >>> distributed_model = DataParallelWithModelAttributes(model)
+        >>> distributed_model.config.n_learned_features
+        4
+    """
+
+    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
+        """Allow access to underlying model attributes/methods.
+
+        Args:
+            name: Attribute/method name.
+
+        Returns:
+            Attribute value/method.
+        """
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.module, name)
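For context on why the `__getattr__` override is needed: `nn.Module.__getattr__` only resolves registered parameters, buffers, and submodules, and `DataParallel` registers the wrapped model solely as `self.module`, so plain Python attributes are unreachable through the wrapper. A minimal sketch of the failure mode this class avoids (`ToyModel` is a hypothetical stand-in, not part of this commit):

from torch import nn


class ToyModel(nn.Module):
    """Hypothetical module with a plain Python attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 4)
        self.n_learned_features = 4  # plain attribute, not a parameter/buffer/submodule


wrapped = nn.DataParallel(ToyModel())

print(wrapped.module.n_learned_features)  # 4: explicit access via .module works
# print(wrapped.n_learned_features)       # AttributeError with plain DataParallel

# DataParallelWithModelAttributes catches that AttributeError and retries the
# lookup on self.module, so the commented-out line above would succeed.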
