Skip to content

Commit 45e10b7

Browse files
committed
Using string_list for model.handlers argument.
1 parent 64a5338 commit 45e10b7

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

tests/unit_tests/test_job_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def test_parse_pp_split_points(self):
116116
config.experimental.pipeline_parallel_split_points == cmdline_splits
117117
), config.experimental.pipeline_parallel_split_points
118118

119+
def test_job_config_model_handlers_split(self):
120+
config = JobConfig()
121+
config.parse_args(["--model.handlers", "float8,mxfp"])
122+
assert config.model.handlers == ["float8", "mxfp"]
123+
119124
def test_print_help(self):
120125
config = JobConfig()
121126
parser = config.parser

torchtitan/config_manager.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,22 @@
2626

2727

2828
def string_list(raw_arg):
29+
"""Comma-separated string list argument."""
2930
return raw_arg.split(",")
3031

3132

33+
def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
34+
section, name = fullargname.split(".")
35+
# Split string list which are still raw strings.
36+
if (
37+
section in args_dict
38+
and name in args_dict[section]
39+
and isinstance(args_dict[section][name], str)
40+
):
41+
sec = args_dict[section]
42+
sec[name] = string_list(sec[name])
43+
44+
3245
class JobConfig:
3346
"""
3447
A helper class to manage the train configuration.
@@ -184,8 +197,9 @@ def __init__(self):
184197
)
185198
self.parser.add_argument(
186199
"--model.handlers",
187-
type=str,
188-
default="",
200+
type=string_list,
201+
nargs="+",
202+
default=[],
189203
help="""
190204
Comma separated list of handlers to apply to the model.
191205
@@ -617,19 +631,12 @@ def parse_args(self, args_list: list = sys.argv[1:]):
617631
)
618632
logger.exception(f"Error details: {str(e)}")
619633
raise e
620-
634+
635+
# Checking string-list arguments are properly split into a list
621636
# if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
622-
if (
623-
"experimental" in args_dict
624-
and "pipeline_parallel_split_points" in args_dict["experimental"]
625-
and isinstance(
626-
args_dict["experimental"]["pipeline_parallel_split_points"], str
627-
)
628-
):
629-
exp = args_dict["experimental"]
630-
exp["pipeline_parallel_split_points"] = string_list(
631-
exp["pipeline_parallel_split_points"]
632-
)
637+
string_list_argnames = self._get_string_list_argument_names()
638+
for n in string_list_argnames:
639+
check_string_list_argument(args_dict, n)
633640

634641
# override args dict with cmd_args
635642
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
@@ -657,13 +664,21 @@ def _validate_config(self) -> None:
657664
assert self.model.flavor
658665
assert self.model.tokenizer_path
659666

667+
def _get_string_list_argument_names(self) -> list[str]:
668+
"""Get the parser argument names of type `string_list`."""
669+
string_list_args = [
670+
v.dest for v in self.parser._actions if v.type is string_list
671+
]
672+
return string_list_args
673+
660674
def parse_args_from_command_line(
661675
self, args_list
662676
) -> Tuple[argparse.Namespace, argparse.Namespace]:
663677
"""
664678
Parse command line arguments and return the parsed args and the command line only args
665679
"""
666680
args = self.parser.parse_args(args_list)
681+
string_list_argnames = set(self._get_string_list_argument_names())
667682

668683
# aux parser to parse the command line only args, with no defaults from main parser
669684
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
@@ -672,7 +687,7 @@ def parse_args_from_command_line(
672687
aux_parser.add_argument(
673688
"--" + arg, action="store_true" if val else "store_false"
674689
)
675-
elif arg == "experimental.pipeline_parallel_split_points":
690+
elif arg in string_list_argnames:
676691
# without this special case, type inference breaks here,
677692
# since the inferred type is just 'list' and it ends up flattening
678693
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333

3434
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
3535
from torchtitan.logging import logger
36-
from torchtitan.parallelisms.parallel_dims import ParallelDims
3736
from torchtitan.model_handler import parse_model_handlers
37+
from torchtitan.parallelisms.parallel_dims import ParallelDims
38+
3839

3940
def parallelize_llama(
4041
model: nn.Module,

0 commit comments

Comments
 (0)