
Commit c1c2f55

[PoC] Typed JobConfig (#767)
In case there's any interest in type hinting of config args, at the expense of an added dependency (`tyro`). Re: #753

Co-authored-by: Jayson Francis <[email protected]>
1 parent 78b0d60 commit c1c2f55

24 files changed: +950 −1,061 lines

.ci/docker/requirements.txt

+1

@@ -7,3 +7,4 @@ blobfile
 tabulate
 wandb
 fsspec
+tyro
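
`tyro` is the new dependency that powers the typed config: it builds a command-line parser directly from dataclass definitions and returns a fully typed object. A minimal standalone sketch of that idea, for orientation only (this snippet is illustrative and not part of the commit):

```python
# Illustrative tyro usage; assumes only the `tyro` package added above.
from dataclasses import dataclass, field

import tyro


@dataclass
class Training:
    steps: int = 500
    """Number of training steps."""


@dataclass
class Config:
    training: Training = field(default_factory=Training)


if __name__ == "__main__":
    # tyro derives flags such as `--training.steps` from the annotations
    # and returns a typed Config instance (no manual argparse wiring).
    config = tyro.cli(Config)
    print(config.training.steps)
```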

docs/extension.md

+41 −13

@@ -35,24 +35,52 @@ This is an ongoing effort, and the level of grouping is subject to change.
 
 
 ### Extending `JobConfig`
-[`JobConfig`](../torchtitan/config_manager.py) provides an argument `--experimental.custom_args_module`. When specified, `JobConfig` attempts to import the module provided by the argument. The imported module should contain exactly one public function. `JobConfig` executes this public function, passing its own argparser as an argument. This allows you to extend `JobConfig` with custom functionality.
 
-Suppose you want to add a custom argument `--custom_args.how-is-your-day` to `JobConfig`. You can create a Python module (e.g., `custom_args.py`) with a single public function and put it to `torchtitan/experiments/your_folder/`:
+[`JobConfig`](../torchtitan/config_manager.py) supports custom extension through the `--experimental.custom_args_module` flag.
+This lets you define a custom module that extends `JobConfig` with additional fields.
 
-```
-import argparse
+When specified, your custom `JobConfig` is merged with the default:
+
+- If a field exists in both, the custom config's value replaces the default.
+- Fields unique to either config are retained.
+
+#### Example
+
+To add a custom `custom_args` section, define your own `JobConfig`:
+
+```python
+# torchtitan/experiments/your_folder/custom_args.py
+from dataclasses import dataclass, field
+
+@dataclass
+class CustomArgs:
+    how_is_your_day: str = "good"
+    """Just an example."""
 
+@dataclass
+class Training:
+    steps: int = 500
+    """Replaces the default value"""
 
-def extend_parser(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument(
-        "--custom_args.how-is-your-day",
-        type=str,
-        default="good",
-        help="Just an example.",
-    )
+    my_mini_steps: int = 10000
+    """New field is added"""
+
+    ...  # Original fields are preserved
+
+@dataclass
+class JobConfig:
+    custom_args: CustomArgs = field(default_factory=CustomArgs)
+    training: Training = field(default_factory=Training)
 ```
 
-To utilize the custom argument, specify the following arguments when running the training script:
+Then run your script with:
+
+```bash
+--experimental.custom_args_module=torchtitan.experiments.your_folder.custom_args
 ```
---experimental.custom_args_module=torchtitan.experiments.your_folder.custom_args --custom_args.how-is-your-day=wonderful
+
+Or specify it in your `.toml` config:
+
+```toml
+[experimental]
+custom_args_module = "torchtitan.experiments.your_folder.custom_args"
 ```
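
For orientation, here is a hedged sketch of how the merged config surfaces at runtime, based on the documentation above (assuming `ConfigManager.parse_args` returns the merged, typed `JobConfig`; values follow the example module):

```python
# Hypothetical usage sketch derived from the docs above; not code from this commit.
from torchtitan.config_manager import ConfigManager

config_manager = ConfigManager()
config = config_manager.parse_args(
    ["--experimental.custom_args_module=torchtitan.experiments.your_folder.custom_args"]
)

print(config.custom_args.how_is_your_day)  # "good"  (new section from the custom config)
print(config.training.steps)               # 500    (custom value replaces the default)
print(config.training.my_mini_steps)       # 10000  (new field added alongside the originals)
```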

scripts/estimate/estimation.py

+3 −3

@@ -18,7 +18,7 @@
 from torchtitan.components.ft import init_ft_manager
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.protocols.train_spec import get_train_spec

@@ -190,8 +190,8 @@ def estimate_memory(job_config: JobConfig):
 
 
 if __name__ == "__main__":
-    config = JobConfig()
-    config.parse_args()
+    config_manager = ConfigManager()
+    config = config_manager.parse_args()
     try:
         estimate_memory(config)
     finally:
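
The same call-site migration recurs in the remaining files of this commit: construct a `ConfigManager` and use the `JobConfig` returned by `parse_args`, rather than mutating a `JobConfig` in place. A minimal sketch of the recurring pattern:

```python
# Migration pattern repeated across the scripts in this commit (sketch only).
from torchtitan.config_manager import ConfigManager

# Before: config = JobConfig(); config.parse_args()   # parsed/mutated in place
# After:  parse_args returns the populated, typed JobConfig.
config_manager = ConfigManager()
config = config_manager.parse_args()  # optionally pass an explicit list of CLI args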

scripts/generate/test_generate.py

+8 −9

@@ -26,7 +26,7 @@
 )
 from torchtitan.components.metrics import build_device_memory_monitor
 
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import ConfigManager
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.protocols.train_spec import get_train_spec
 from torchtitan.tools import utils

@@ -85,9 +85,8 @@ def test_generate(
     color = utils.Color
 
     # Load configuration from toml file
-    job_config = JobConfig()
-    job_config.parse_args([f"--job.config_file={config_path}"])
-    job_config._validate_config()
+    config_manager = ConfigManager()
+    config = config_manager.parse_args([f"--job.config_file={config_path}"])
 
     if len(args.prompt) == 0:
         logger.warning(

@@ -100,16 +99,16 @@
     device_module.set_device(device)
     device_memory_monitor = build_device_memory_monitor()
 
-    train_spec = get_train_spec(job_config.model.name)
+    train_spec = get_train_spec(config.model.name)
 
     logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")
 
     # Tokenizer setup
-    tokenizer = train_spec.build_tokenizer_fn(job_config)
+    tokenizer = train_spec.build_tokenizer_fn(config)
 
     model_cls = train_spec.cls
-    model_args = train_spec.config[job_config.model.flavor]
-    model_args.update_from_config(job_config, tokenizer)
+    model_args = train_spec.config[config.model.flavor]
+    model_args.update_from_config(config, tokenizer)
 
     init_device = "meta" if world_size > 1 else device
     with torch.device(init_device):

@@ -119,7 +118,7 @@
     world_mesh = None
     # Init distributed env
     if world_size > 1:
-        dist_utils.init_distributed(job_config)
+        dist_utils.init_distributed(config)
         parallel_dims = ParallelDims(
             dp_replicate=1,
             dp_shard=-1,

tests/assets/argparser_example.py

−16 (the previous example in this file was deleted) / +32 (replaced by the content below)

@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class CustomArgs:
+    how_is_your_day: str = "good"
+    """Just an example helptext"""
+
+    num_days: int = 7
+    """Number of days in a week"""
+
+
+@dataclass
+class Training:
+    steps: int = 99
+    my_custom_steps: int = 32
+
+
+@dataclass
+class JobConfig:
+    """
+    This is an example of how to extend the tyro parser with custom config classes.
+    """
+
+    custom_args: CustomArgs = field(default_factory=CustomArgs)
+    training: Training = field(default_factory=Training)
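
A test exercising this asset would presumably load it through the `--experimental.custom_args_module` flag documented above. A hedged sketch (the module path `tests.assets.argparser_example` is an assumption based on the filename shown in this diff):

```python
# Hypothetical test sketch; module path and exact assertions are assumptions.
from torchtitan.config_manager import ConfigManager

config_manager = ConfigManager()
config = config_manager.parse_args(
    ["--experimental.custom_args_module=tests.assets.argparser_example"]
)

assert config.custom_args.how_is_your_day == "good"
assert config.custom_args.num_days == 7
assert config.training.steps == 99  # the example's value overrides the default
```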

tests/unit_tests/test_dataset_checkpointing.py

+3 −3

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import ConfigManager
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
 

@@ -36,8 +36,8 @@ def test_c4_resumption(self):
 
     def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):
         tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model")
-        config = JobConfig()
-        config.parse_args(
+        config_manager = ConfigManager()
+        config = config_manager.parse_args(
             [
                 "--training.dataset",
                 dataset_name,
