26
26
27
27
28
28
def string_list(raw_arg: str) -> list[str]:
    """Parse a comma-separated string into a list of strings.

    Surrounding whitespace is stripped from every item and empty items are
    dropped, so ``""`` yields ``[]`` (not ``[""]``) and ``"a, b,"`` yields
    ``["a", "b"]``.

    Args:
        raw_arg: The raw comma-separated value.

    Returns:
        The non-empty, whitespace-stripped items in their original order.
    """
    stripped_items = (item.strip() for item in raw_arg.split(","))
    return [item for item in stripped_items if item]
30
31
31
32
33
def check_string_list_argument(args_dict: dict[str, dict], fullargname: str) -> None:
    """Split a string-list option in place if it is still a raw string.

    ``args_dict`` is accessed as a two-level mapping
    (``args_dict[section][name]``). If the option named by ``fullargname``
    is present and its value is a ``str``, the value is replaced by
    ``string_list(value)``. In every other case the mapping is left untouched.

    Args:
        args_dict: Two-level mapping of section name to option mapping;
            mutated in place.
        fullargname: Dotted option name of the form ``"section.name"``. Only
            the first dot separates the section from the option name.
    """
    section, dot, name = fullargname.partition(".")
    if not dot:
        # No section prefix: there is nothing to look up in a two-level dict.
        return
    if section not in args_dict:
        return

    options = args_dict[section]
    # Split string lists which are still raw strings; values that are
    # already lists (or any other non-str type) are kept as they are.
    if name in options and isinstance(options[name], str):
        options[name] = string_list(options[name])
43
+
44
+
32
45
class JobConfig :
33
46
"""
34
47
A helper class to manage the train configuration.
@@ -184,8 +197,9 @@ def __init__(self):
184
197
)
185
198
self .parser .add_argument (
186
199
"--model.handlers" ,
187
- type = str ,
188
- default = "" ,
200
+ type = string_list ,
201
+ nargs = "+" ,
202
+ default = [],
189
203
help = """
190
204
Comma separated list of handlers to apply to the model.
191
205
@@ -617,19 +631,12 @@ def parse_args(self, args_list: list = sys.argv[1:]):
617
631
)
618
632
logger .exception (f"Error details: { str (e )} " )
619
633
raise e
620
-
634
+
635
+ # Checking string-list arguments are properly split into a list
621
636
# if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
622
- if (
623
- "experimental" in args_dict
624
- and "pipeline_parallel_split_points" in args_dict ["experimental" ]
625
- and isinstance (
626
- args_dict ["experimental" ]["pipeline_parallel_split_points" ], str
627
- )
628
- ):
629
- exp = args_dict ["experimental" ]
630
- exp ["pipeline_parallel_split_points" ] = string_list (
631
- exp ["pipeline_parallel_split_points" ]
632
- )
637
+ string_list_argnames = self ._get_string_list_argument_names ()
638
+ for n in string_list_argnames :
639
+ check_string_list_argument (args_dict , n )
633
640
634
641
# override args dict with cmd_args
635
642
cmd_args_dict = self ._args_to_two_level_dict (cmd_args )
@@ -657,13 +664,21 @@ def _validate_config(self) -> None:
657
664
assert self .model .flavor
658
665
assert self .model .tokenizer_path
659
666
667
def _get_string_list_argument_names(self) -> list[str]:
    """Collect the names of all parser arguments typed as `string_list`.

    Returns:
        The `dest` of every action registered on `self.parser` whose `type`
        is the `string_list` function, in registration order.
    """
    names = []
    # argparse offers no public accessor for registered actions, hence the
    # use of the private `_actions` list.
    for action in self.parser._actions:
        if action.type is string_list:
            names.append(action.dest)
    return names
673
+
660
674
def parse_args_from_command_line (
661
675
self , args_list
662
676
) -> Tuple [argparse .Namespace , argparse .Namespace ]:
663
677
"""
664
678
Parse command line arguments and return the parsed args and the command line only args
665
679
"""
666
680
args = self .parser .parse_args (args_list )
681
+ string_list_argnames = set (self ._get_string_list_argument_names ())
667
682
668
683
# aux parser to parse the command line only args, with no defaults from main parser
669
684
aux_parser = argparse .ArgumentParser (argument_default = argparse .SUPPRESS )
@@ -672,7 +687,7 @@ def parse_args_from_command_line(
672
687
aux_parser .add_argument (
673
688
"--" + arg , action = "store_true" if val else "store_false"
674
689
)
675
- elif arg == "experimental.pipeline_parallel_split_points" :
690
+ elif arg in string_list_argnames :
676
691
# without this special case, type inference breaks here,
677
692
# since the inferred type is just 'list' and it ends up flattening
678
693
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
0 commit comments