diff --git a/tensorflow_datasets/scripts/cli/build.py b/tensorflow_datasets/scripts/cli/build.py
index 509dbc7c427..5e77fc17148 100644
--- a/tensorflow_datasets/scripts/cli/build.py
+++ b/tensorflow_datasets/scripts/cli/build.py
@@ -15,7 +15,7 @@
 """`tfds build` command."""

-import argparse
+import dataclasses
 import functools
 import importlib
 import itertools
@@ -25,76 +25,74 @@
 from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union

 from absl import logging
+import simple_parsing
 import tensorflow_datasets as tfds
 from tensorflow_datasets.scripts.cli import cli_utils

 # pylint: disable=logging-fstring-interpolation


-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command."""
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
-      nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _AutomationGroup:
+  """Used by automated scripts.

-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
+  Attributes:
+    exclude_datasets: If set, generate all datasets except the ones defined
+      here. Comma separated list of datasets to exclude.
+    experimental_latest_version: Build the latest Version(experiments=...)
+      available rather than the default version.
+  """

-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
-  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  exclude_datasets: list[str] = cli_utils.comma_separated_list_field()
+  experimental_latest_version: bool = False
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """Commands for downloading and preparing datasets.
+
+  Attributes:
+    datasets: Name(s) of the dataset(s) to build. Defaults to the current dir.
+      See https://www.tensorflow.org/datasets/cli for accepted values.
+    datasets_keyword: Datasets can also be provided as a keyword argument.
+    debug: Debug & tests options.
+    path: Paths options.
+    generation: Generation options.
+    publish: Publishing options.
+    automation: Automation options.
+  """
+
+  datasets: list[str] = simple_parsing.field(
+      positional=True, default_factory=list, nargs='*'
   )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  datasets_keyword: list[str] = simple_parsing.field(
+      alias='datasets', default_factory=list, nargs='*'
   )
+  debug: cli_utils.DebugGroup = simple_parsing.field(prefix='')
+  path: cli_utils.PathGroup = simple_parsing.field(prefix='')
+  generation: cli_utils.GenerationGroup = simple_parsing.field(prefix='')
+  publish: cli_utils.PublishGroup = simple_parsing.field(prefix='')
+  automation: _AutomationGroup = simple_parsing.field(prefix='')

-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self):
+    _build_datasets(self)


-def _build_datasets(args: argparse.Namespace) -> None:
+def _build_datasets(args: CmdArgs) -> None:
   """Build the given datasets."""
   # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+  if args.generation.imports:
+    list(importlib.import_module(m) for m in args.generation.imports)

   # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
+  datasets = args.datasets + args.datasets_keyword
+  if (
+      args.automation.exclude_datasets
+  ):  # Generate all datasets if `--exclude_datasets` set
     if datasets:
       raise ValueError("--exclude_datasets can't be used with `datasets`")
     datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
+        args.automation.exclude_datasets
     )
     datasets = sorted(datasets)  # `set` is not deterministic
   else:
@@ -102,7 +100,9 @@ def _build_datasets(args: argparse.Namespace) -> None:

   # Import builder classes
   builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
+      _get_builder_cls_and_kwargs(
+          dataset, has_imports=bool(args.generation.imports)
+      )
       for dataset in datasets
   ]

@@ -112,19 +112,20 @@ def _build_datasets(args: argparse.Namespace) -> None:
       for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
   ))
   process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
+      _download if args.generation.download_only else _download_and_prepare,
+      args,
   )

-  if args.num_processes == 1:
+  if args.generation.num_processes == 1:
     for builder in builders:
       process_builder_fn(builder)
   else:
-    with multiprocessing.Pool(args.num_processes) as pool:
+    with multiprocessing.Pool(args.generation.num_processes) as pool:
       pool.map(process_builder_fn, builders)


 def _make_builders(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder_cls: Type[tfds.core.DatasetBuilder],
     builder_kwargs: Dict[str, Any],
 ) -> Iterator[tfds.core.DatasetBuilder]:
@@ -139,7 +140,7 @@ def _make_builders(
     Initialized dataset builders.
""" # Eventually overwrite version - if args.experimental_latest_version: + if args.automation.experimental_latest_version: if 'version' in builder_kwargs: raise ValueError( "Can't have both `--experimental_latest` and version set (`:1.0.0`)" @@ -150,19 +151,19 @@ def _make_builders( builder_kwargs['config'] = _get_config_name( builder_cls=builder_cls, config_kwarg=builder_kwargs.get('config'), - config_name=args.config, - config_idx=args.config_idx, + config_name=args.generation.config, + config_idx=args.generation.config_idx, ) - if args.file_format: - builder_kwargs['file_format'] = args.file_format + if args.generation.file_format: + builder_kwargs['file_format'] = args.generation.file_format make_builder = functools.partial( _make_builder, builder_cls, - overwrite=args.overwrite, - fail_if_exists=args.fail_if_exists, - data_dir=args.data_dir, + overwrite=args.debug.overwrite, + fail_if_exists=args.debug.fail_if_exists, + data_dir=args.path.data_dir, **builder_kwargs, ) @@ -301,7 +302,7 @@ def _make_builder( def _download( - args: argparse.Namespace, + args: CmdArgs, builder: tfds.core.DatasetBuilder, ) -> None: """Downloads all files of the given builder.""" @@ -323,7 +324,7 @@ def _download( if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None: max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS - download_dir = args.download_dir or os.path.join( + download_dir = args.path.download_dir or os.path.join( builder._data_dir_root, 'downloads' # pylint: disable=protected-access ) dl_manager = tfds.download.DownloadManager( @@ -345,35 +346,35 @@ def _download( def _download_and_prepare( - args: argparse.Namespace, + args: CmdArgs, builder: tfds.core.DatasetBuilder, ) -> None: """Generate a single builder.""" cli_utils.download_and_prepare( builder=builder, download_config=_make_download_config(args, dataset_name=builder.name), - download_dir=args.download_dir, - publish_dir=args.publish_dir, - skip_if_published=args.skip_if_published, - overwrite=args.overwrite, + download_dir=args.path.download_dir, + publish_dir=args.publish.publish_dir, + skip_if_published=args.publish.skip_if_published, + overwrite=args.debug.overwrite, ) def _make_download_config( - args: argparse.Namespace, + args: CmdArgs, dataset_name: str, ) -> tfds.download.DownloadConfig: """Generate the download and prepare configuration.""" # Load the download config - manual_dir = args.manual_dir - if args.add_name_to_manual_dir: + manual_dir = args.path.manual_dir + if args.path.add_name_to_manual_dir: manual_dir = manual_dir / dataset_name kwargs = {} - if args.max_shard_size_mb: - kwargs['max_shard_size'] = args.max_shard_size_mb << 20 - if args.download_config: - kwargs.update(json.loads(args.download_config)) + if args.generation.max_shard_size_mb: + kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20 + if args.generation.download_config: + kwargs.update(json.loads(args.generation.download_config)) if 'download_mode' in kwargs: kwargs['download_mode'] = tfds.download.GenerateMode( @@ -381,15 +382,15 @@ def _make_download_config( ) else: kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS - if args.update_metadata_only: + if args.generation.update_metadata_only: kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO dl_config = tfds.download.DownloadConfig( - extract_dir=args.extract_dir, + extract_dir=args.path.extract_dir, manual_dir=manual_dir, - max_examples_per_split=args.max_examples_per_split, - register_checksums=args.register_checksums, - 
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
       **kwargs,
   )

@@ -400,9 +401,9 @@ def _make_download_config(
     beam = None

   if beam is not None:
-    if args.beam_pipeline_options:
+    if args.generation.beam_pipeline_options:
       dl_config.beam_options = beam.options.pipeline_options.PipelineOptions(
-          flags=[f'--{opt}' for opt in args.beam_pipeline_options.split(',')]
+          flags=[f'--{opt}' for opt in args.generation.beam_pipeline_options]
       )

   return dl_config
diff --git a/tensorflow_datasets/scripts/cli/build_test.py b/tensorflow_datasets/scripts/cli/build_test.py
index 20864e2e29d..b28867215ba 100644
--- a/tensorflow_datasets/scripts/cli/build_test.py
+++ b/tensorflow_datasets/scripts/cli/build_test.py
@@ -316,7 +316,7 @@ def test_download_only():
 )
 def test_make_download_config(args: str, download_config_kwargs):
   args = main._parse_flags(f'tfds build x {args}'.split())
-  actual = build_lib._make_download_config(args, dataset_name='x')
+  actual = build_lib._make_download_config(args.command, dataset_name='x')
   # Ignore the beam runner
   actual = actual.replace(beam_runner=None)
   expected = tfds.download.DownloadConfig(**download_config_kwargs)
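The `build` command above now parses into a frozen dataclass instead of an `argparse.Namespace`, and the parsed object carries its own `execute()` entry point. A minimal sketch of that pattern follows; the `Args` class is a hypothetical stand-in (not the real `build.CmdArgs`), and it assumes `simple_parsing`'s `ArgumentParser.add_arguments` API as used elsewhere in this change.

```python
import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True, kw_only=True)
class Args:
  """Toy stand-in for a command dataclass."""

  datasets: list[str] = dataclasses.field(default_factory=list)
  overwrite: bool = False

  def execute(self) -> None:
    # The real CmdArgs.execute() delegates to _build_datasets(self).
    print(f'build {self.datasets} overwrite={self.overwrite}')


parser = simple_parsing.ArgumentParser()
parser.add_arguments(Args, dest='args')
namespace = parser.parse_args(
    ['--datasets', 'mnist', 'coco', '--overwrite', 'true']
)
namespace.args.execute()
```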
diff --git a/tensorflow_datasets/scripts/cli/cli_utils.py b/tensorflow_datasets/scripts/cli/cli_utils.py
index c33e9042c2f..19192259e28 100644
--- a/tensorflow_datasets/scripts/cli/cli_utils.py
+++ b/tensorflow_datasets/scripts/cli/cli_utils.py
@@ -15,7 +15,6 @@
 """Utility functions for TFDS CLI."""

-import argparse
 import dataclasses
 import itertools
 import os
@@ -23,6 +22,7 @@
 from absl import logging
 from etils import epath
+import simple_parsing
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import download
 from tensorflow_datasets.core import file_adapters
@@ -80,216 +80,133 @@ def __post_init__(self):
     self.ds_import = ds_import


-def add_debug_argument_group(parser: argparse.ArgumentParser):
-  """Adds debug argument group to the parser."""
-  debug_group = parser.add_argument_group(
-      'Debug & tests',
-      description=(
-          '--pdb Enter post-mortem debugging mode if an exception is raised.'
-      ),
-  )
-  debug_group.add_argument(
-      '--overwrite',
-      action='store_true',
-      help='Delete pre-existing dataset if it exists.',
-  )
-  debug_group.add_argument(
-      '--fail_if_exists',
-      action='store_true',
-      default=False,
-      help='Fails the program if there is a pre-existing dataset.',
-  )
-  debug_group.add_argument(
-      '--max_examples_per_split',
-      type=int,
+def comma_separated_list_field(**kwargs):
+  """Returns a field that parses a comma-separated list of values."""
+  # Need to manually parse comma-separated list of values, see:
+  # https://github.com/lebrice/SimpleParsing/issues/142.
+  return simple_parsing.field(
+      **kwargs,
+      default_factory=list,
+      type=lambda value: value.split(','),
       nargs='?',
-      const=1,
-      help=(
-          'When set, only generate the first X examples (default to 1), rather'
-          ' than the full dataset.If set to 0, only execute the'
-          ' `_split_generators` (which download the original data), but skip'
-          ' `_generator_examples`'
-      ),
   )


-def add_path_argument_group(parser: argparse.ArgumentParser):
-  """Adds path argument group to the parser."""
-  path_group = parser.add_argument_group('Paths')
-  path_group.add_argument(
-      '--data_dir',
-      type=epath.Path,
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class DebugGroup:
+  """--pdb Enter post-mortem debugging mode if an exception is raised.
+
+  Attributes:
+    overwrite: Delete pre-existing dataset if it exists.
+    fail_if_exists: Fails the program if there is a pre-existing dataset.
+    max_examples_per_split: When set, only generate the first X examples
+      (defaults to 1), rather than the full dataset. If set to 0, only execute
+      the `_split_generators` (which downloads the original data), but skip
+      `_generate_examples`.
+  """
+
+  overwrite: bool = False
+  fail_if_exists: bool = False
+  max_examples_per_split: int = simple_parsing.field(nargs='?', const=1)
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class PathGroup:
+  """Path related arguments.
+
+  Attributes:
+    data_dir: Where to place datasets. Defaults to `~/tensorflow_datasets/` or
+      the `TFDS_DATA_DIR` environment variable.
+    download_dir: Where to place downloads. Defaults to
+      `<data_dir>/downloads/`.
+    extract_dir: Where to extract files. Defaults to
+      `<download_dir>/extracted/`.
+    manual_dir: Where to manually download data (required for some datasets).
+      Defaults to `<download_dir>/manual/`.
+    add_name_to_manual_dir: If true, append the dataset name to the
+      `manual_dir` (e.g. `<download_dir>/manual/<dataset_name>/`). Useful to
+      avoid collisions if many datasets are generated.
+  """
+
+  data_dir: epath.Path = simple_parsing.field(
       # Should match tfds.core.constant.DATA_DIR !!
-      default=epath.Path(
+      default_factory=lambda: epath.Path(
          os.environ.get(
              'TFDS_DATA_DIR',
              os.path.join(os.path.expanduser('~'), 'tensorflow_datasets'),
          )
-      ),
-      help=(
-          'Where to place datasets. Default to '
-          '`~/tensorflow_datasets/` or `TFDS_DATA_DIR` environement variable.'
-      ),
-  )
-  path_group.add_argument(
-      '--download_dir',
-      type=epath.Path,
-      help='Where to place downloads. Default to `<data_dir>/downloads/`.',
-  )
-  path_group.add_argument(
-      '--extract_dir',
-      type=epath.Path,
-      help='Where to extract files. Default to `<download_dir>/extracted/`.',
-  )
-  path_group.add_argument(
-      '--manual_dir',
-      type=epath.Path,
-      help=(
-          'Where to manually download data (required for some datasets). '
-          'Default to `<download_dir>/manual/`.'
-      ),
-  )
-  path_group.add_argument(
-      '--add_name_to_manual_dir',
-      action='store_true',
-      help=(
-          'If true, append the dataset name to the `manual_dir` (e.g. '
-          '`<download_dir>/manual/<dataset_name>/`. Useful to avoid collisions '
-          'if many datasets are generated.'
-      ),
+      )
   )
+  download_dir: epath.Path | None = None
+  extract_dir: epath.Path | None = None
+  manual_dir: epath.Path | None = None
+  add_name_to_manual_dir: bool = False


-def add_generation_argument_group(parser: argparse.ArgumentParser):
-  """Adds generation argument group to the parser."""
-  generation_group = parser.add_argument_group('Generation')
-  generation_group.add_argument(
-      '--download_only',
-      action='store_true',
-      help=(
-          'If True, download all files but do not prepare the dataset. Uses the'
-          ' checksum.tsv to find out what to download. Therefore, this does not'
-          ' work in combination with --register_checksums.'
-      ),
-  )
-  generation_group.add_argument(
-      '--config',
-      '-c',
-      type=str,
-      help=(
-          'Config name to build. Build all configs if not set. Can also be a'
-          ' json of the kwargs forwarded to the config `__init__` (for custom'
-          ' configs).'
-      ),
-  )
-  # We are forced to have 2 flags to avoid ambiguity when config name is
-  # a number (e.g. `voc/2017`)
-  generation_group.add_argument(
-      '--config_idx',
-      type=int,
-      help=(
-          'Config id to build (`builder_cls.BUILDER_CONFIGS[config_idx]`). '
-          'Mutually exclusive with `--config`.'
-      ),
-  )
-  generation_group.add_argument(
-      '--update_metadata_only',
-      action='store_true',
-      default=False,
-      help=(
-          'If True, existing dataset_info.json is updated with metadata defined'
-          ' in Builder class(es). Datasets must already have been prepared.'
-      ),
-  )
-  generation_group.add_argument(
-      '--download_config',
-      type=str,
-      help=(
-          'A json of the kwargs forwarded to the config `__init__` (for custom'
-          ' DownloadConfigs).'
-      ),
-  )
-  generation_group.add_argument(
-      '--imports',
-      '-i',
-      type=str,
-      help='Comma separated list of module to import to register datasets.',
-  )
-  generation_group.add_argument(
-      '--register_checksums',
-      action='store_true',
-      help='If True, store size and checksum of downloaded files.',
-  )
-  generation_group.add_argument(
-      '--force_checksums_validation',
-      action='store_true',
-      help='If True, raise an error if the checksums are not found.',
-  )
-  # For compatibility with absl.flags (which generates --foo and --nofoo).
-  generation_group.add_argument(
-      '--noforce_checksums_validation',
-      dest='force_checksums_validation',
-      action='store_false',
-      help='If specified, bypass the checks on the checksums.',
-  )
-  generation_group.add_argument(
-      '--beam_pipeline_options',
-      type=str,
-      # nargs='+',
-      help=(
-          'A (comma-separated) list of flags to pass to `PipelineOptions` when'
-          ' preparing with Apache Beam. (see:'
-          ' https://www.tensorflow.org/datasets/beam_datasets). Example:'
-          ' `--beam_pipeline_options=job_name=my-job,project=my-project`'
-      ),
-  )
-  format_values = [f.value for f in file_adapters.FileFormat]
-  generation_group.add_argument(
-      '--file_format',
-      type=str,
-      help=(
-          'File format to which generate the tf-examples. '
-          f'Available values: {format_values} (see `tfds.core.FileFormat`).'
-      ),
-  )
-  generation_group.add_argument(
-      '--max_shard_size_mb', type=int, help='The max shard size in megabytes.'
-  )
-  generation_group.add_argument(
-      '--num-processes',
-      type=int,
-      default=1,
-      help='Number of parallel build processes.',
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class GenerationGroup:
+  """Generation related arguments.
+
+  Attributes:
+    download_only: If True, download all files but do not prepare the dataset.
+      Uses the checksum.tsv to find out what to download. Therefore, this does
+      not work in combination with --register_checksums.
+    config: Config name to build. Build all configs if not set. Can also be a
+      json of the kwargs forwarded to the config `__init__` (for custom
+      configs).
+    config_idx: Config id to build (`builder_cls.BUILDER_CONFIGS[config_idx]`).
+      Mutually exclusive with `--config`.
+    update_metadata_only: If True, existing dataset_info.json is updated with
+      metadata defined in Builder class(es). Datasets must already have been
+      prepared.
+    download_config: A json of the kwargs forwarded to the config `__init__`
+      (for custom DownloadConfigs).
+    imports: Comma separated list of modules to import to register datasets.
+    register_checksums: If True, store size and checksum of downloaded files.
+    force_checksums_validation: If True, raise an error if the checksums are
+      not found.
+    beam_pipeline_options: A (comma-separated) list of flags to pass to
+      `PipelineOptions` when preparing with Apache Beam. (see:
+      https://www.tensorflow.org/datasets/beam_datasets). Example:
+      `--beam_pipeline_options=job_name=my-job,project=my-project`
+    file_format: File format in which to generate the tf-examples.
+    max_shard_size_mb: The max shard size in megabytes.
+    num_processes: Number of parallel build processes.
+  """

-def add_publish_argument_group(parser: argparse.ArgumentParser):
-  """Adds publish argument group to the parser."""
-  publish_group = parser.add_argument_group(
-      'Publishing',
-      description='Options for publishing successfully created datasets.',
-  )
-  publish_group.add_argument(
-      '--publish_dir',
-      type=epath.Path,
+  download_only: bool = False
+  config: str | None = simple_parsing.field(default=None, alias='c')
+  # We are forced to have 2 flags to avoid ambiguity when config name is a
+  # number (e.g. `voc/2017`)
+  config_idx: int | None = None
+  update_metadata_only: bool = False
+  download_config: str | None = None
+  imports: list[str] = comma_separated_list_field(alias='i')
+  register_checksums: bool = False
+  force_checksums_validation: bool = False
+  beam_pipeline_options: list[str] = comma_separated_list_field()
+  file_format: str | None = simple_parsing.choice(
+      *(file_format.value for file_format in file_adapters.FileFormat),
       default=None,
-      required=False,
-      help=(
-          'Where to optionally publish the dataset after it has been '
-          'generated successfully. Should be the root data dir under which'
-          'datasets are stored. '
-          'If unspecified, dataset will not be published'
-      ),
-  )
-  publish_group.add_argument(
-      '--skip_if_published',
-      action='store_true',
-      default=False,
-      help=(
-          'If the dataset with the same version and config is already '
-          'published, then it will not be regenerated.'
-      ),
   )
+  max_shard_size_mb: int | None = None
+  num_processes: int = simple_parsing.field(default=1, alias='num-processes')
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class PublishGroup:
+  # pyformat: disable
+  """Options for publishing successfully created datasets.
+
+  Attributes:
+    publish_dir: Where to optionally publish the dataset after it has been
+      generated successfully. Should be the root data dir under which datasets
+      are stored. If unspecified, the dataset will not be published.
+    skip_if_published: If the dataset with the same version and config is
+      already published, then it will not be regenerated.
+ """ + # pyformat: enable + + publish_dir: epath.Path | None = None + skip_if_published: bool = False def download_and_prepare( diff --git a/tensorflow_datasets/scripts/cli/convert_format.py b/tensorflow_datasets/scripts/cli/convert_format.py index de7aeea51a8..d7859cf5807 100644 --- a/tensorflow_datasets/scripts/cli/convert_format.py +++ b/tensorflow_datasets/scripts/cli/convert_format.py @@ -25,113 +25,61 @@ ``` """ -import argparse -from collections.abc import Sequence +import dataclasses from etils import epath +import simple_parsing from tensorflow_datasets.core import file_adapters +from tensorflow_datasets.scripts.cli import cli_utils from tensorflow_datasets.scripts.cli import convert_format_utils -def add_parser_arguments(parser: argparse.ArgumentParser) -> None: - """Add arguments for `convert_format` subparser.""" - parser.add_argument( - '--root_data_dir', - type=str, - help=( - 'Root data dir that contains all datasets. All datasets and all their' - ' configs and versions that are in this folder will be converted.' - ), - required=False, - ) - parser.add_argument( - '--dataset_dir', - type=str, - help=( - 'Path where the dataset to be converted is located. Converts all' - ' configs and versions in this folder.' - ), - required=False, - ) - parser.add_argument( - '--dataset_version_dir', - type=str, - help=( - 'Path where the dataset to be converted is located. Should include' - ' config and version. Can also be a comma-separated list of paths. If' - ' multiple paths are specified, `--out_dir` should not be specified,' - ' since each dataset will be converted in the same directory as the' - ' input dataset.' - ), - required=False, - ) - parser.add_argument( - '--out_file_format', - type=str, - choices=[file_format.value for file_format in file_adapters.FileFormat], - help='File format to convert the dataset to.', - required=True, - ) - parser.add_argument( - '--out_dir', - type=str, - help=( - 'Path where the converted dataset will be stored. Should include the' - ' config and version, e.g. `/data/dataset_name/config/1.2.3`. If not' - ' specified, the converted shards will be stored in the same' - ' directory as the input dataset.' - ), - default='', - required=False, - ) - parser.add_argument( - '--overwrite', - action='store_true', - help='Whether to overwrite the output directory if it already exists.', - ) - parser.add_argument( - '--use_beam', - action='store_true', - help='Use beam to convert the dataset.', - ) - parser.add_argument( - '--num_workers', - type=int, - default=8, - help=( - 'Number of workers to use when not using Beam. If `--use_beam` is' - ' set, this flag is ignored. If `--num_workers=1`, the conversion' - ' will be done sequentially.' - ), - ) +@dataclasses.dataclass(frozen=True, kw_only=True) +class CmdArgs: + """CLI arguments for converting a dataset from one file format to another format. + Attributes: + root_data_dir: Root data dir that contains all datasets. All datasets and + all their configs and versions that are in this folder will be converted. + dataset_dir: Path where the dataset to be converted is located. Converts all + configs and versions in this folder. + dataset_version_dir: Path where the dataset to be converted is located. + Should include config and version. Can also be a comma-separated list of + paths. If multiple paths are specified, `--out_dir` should not be + specified, since each dataset will be converted in the same directory as + the input dataset. + out_file_format: File format to convert the dataset to. 
diff --git a/tensorflow_datasets/scripts/cli/convert_format.py b/tensorflow_datasets/scripts/cli/convert_format.py
index de7aeea51a8..d7859cf5807 100644
--- a/tensorflow_datasets/scripts/cli/convert_format.py
+++ b/tensorflow_datasets/scripts/cli/convert_format.py
@@ -25,113 +25,61 @@
 ```
 """

-import argparse
-from collections.abc import Sequence
+import dataclasses

 from etils import epath
+import simple_parsing
 from tensorflow_datasets.core import file_adapters
+from tensorflow_datasets.scripts.cli import cli_utils
 from tensorflow_datasets.scripts.cli import convert_format_utils


-def add_parser_arguments(parser: argparse.ArgumentParser) -> None:
-  """Add arguments for `convert_format` subparser."""
-  parser.add_argument(
-      '--root_data_dir',
-      type=str,
-      help=(
-          'Root data dir that contains all datasets. All datasets and all their'
-          ' configs and versions that are in this folder will be converted.'
-      ),
-      required=False,
-  )
-  parser.add_argument(
-      '--dataset_dir',
-      type=str,
-      help=(
-          'Path where the dataset to be converted is located. Converts all'
-          ' configs and versions in this folder.'
-      ),
-      required=False,
-  )
-  parser.add_argument(
-      '--dataset_version_dir',
-      type=str,
-      help=(
-          'Path where the dataset to be converted is located. Should include'
-          ' config and version. Can also be a comma-separated list of paths. If'
-          ' multiple paths are specified, `--out_dir` should not be specified,'
-          ' since each dataset will be converted in the same directory as the'
-          ' input dataset.'
-      ),
-      required=False,
-  )
-  parser.add_argument(
-      '--out_file_format',
-      type=str,
-      choices=[file_format.value for file_format in file_adapters.FileFormat],
-      help='File format to convert the dataset to.',
-      required=True,
-  )
-  parser.add_argument(
-      '--out_dir',
-      type=str,
-      help=(
-          'Path where the converted dataset will be stored. Should include the'
-          ' config and version, e.g. `/data/dataset_name/config/1.2.3`. If not'
-          ' specified, the converted shards will be stored in the same'
-          ' directory as the input dataset.'
-      ),
-      default='',
-      required=False,
-  )
-  parser.add_argument(
-      '--overwrite',
-      action='store_true',
-      help='Whether to overwrite the output directory if it already exists.',
-  )
-  parser.add_argument(
-      '--use_beam',
-      action='store_true',
-      help='Use beam to convert the dataset.',
-  )
-  parser.add_argument(
-      '--num_workers',
-      type=int,
-      default=8,
-      help=(
-          'Number of workers to use when not using Beam. If `--use_beam` is'
-          ' set, this flag is ignored. If `--num_workers=1`, the conversion'
-          ' will be done sequentially.'
-      ),
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """CLI arguments for converting a dataset from one file format to another.
+
+  Attributes:
+    root_data_dir: Root data dir that contains all datasets. All datasets and
+      all their configs and versions that are in this folder will be converted.
+    dataset_dir: Path where the dataset to be converted is located. Converts
+      all configs and versions in this folder.
+    dataset_version_dir: Path where the dataset to be converted is located.
+      Should include config and version. Can also be a comma-separated list of
+      paths. If multiple paths are specified, `--out_dir` should not be
+      specified, since each dataset will be converted in the same directory as
+      the input dataset.
+    out_file_format: File format to convert the dataset to.
+    out_dir: Path where the converted dataset will be stored. Should include
+      the config and version, e.g. `/data/dataset_name/config/1.2.3`. If not
+      specified, the converted shards will be stored in the same directory as
+      the input dataset.
+    overwrite: Whether to overwrite the output directory if it already exists.
+    use_beam: Whether to use beam to convert the dataset.
+    num_workers: Number of workers to use when not using Beam. If `--use_beam`
+      is set, this flag is ignored. If `--num_workers=1`, the conversion will
+      be done sequentially.
+  """

-def register_subparser(parsers: argparse._SubParsersAction) -> None:
-  """Add subparser for `convert_format` command."""
-  parser = parsers.add_parser(
-      'convert_format',
-      help='Converts a dataset from one file format to another format.',
+  root_data_dir: epath.PathLike | None = None
+  dataset_dir: epath.PathLike | None = None
+  dataset_version_dir: list[str] = cli_utils.comma_separated_list_field()
+  # Need to override the default use of `Enum.name` for choice options.
+  out_file_format: str = simple_parsing.choice(
+      *(file_format.value for file_format in file_adapters.FileFormat)
   )
-  add_parser_arguments(parser)
+  out_dir: epath.PathLike | None = None
+  overwrite: bool = False
+  use_beam: bool = False
+  num_workers: int = 8

-  def _parse_dataset_version_dir(
-      dataset_version_dir: str | None,
-  ) -> Sequence[epath.Path] | None:
-    if not dataset_version_dir:
-      return None
-    return [epath.Path(path) for path in dataset_version_dir.split(',')]
-
-  parser.set_defaults(
-      subparser_fn=lambda args: convert_format_utils.convert_dataset(
-          out_dir=args.out_dir if args.out_dir else None,
-          out_file_format=args.out_file_format,
-          dataset_dir=args.dataset_dir or None,
-          root_data_dir=args.root_data_dir or None,
-          dataset_version_dir=_parse_dataset_version_dir(
-              args.dataset_version_dir
-          ),
-          overwrite=args.overwrite,
-          use_beam=args.use_beam,
-          num_workers=args.num_workers,
-      )
-  )
+  def execute(self):
+    convert_format_utils.convert_dataset(
+        out_dir=self.out_dir,
+        out_file_format=self.out_file_format,
+        dataset_dir=self.dataset_dir,
+        root_data_dir=self.root_data_dir,
+        dataset_version_dir=self.dataset_version_dir,
+        overwrite=self.overwrite,
+        use_beam=self.use_beam,
+        num_workers=self.num_workers,
+    )
diff --git a/tensorflow_datasets/scripts/cli/croissant.py b/tensorflow_datasets/scripts/cli/croissant.py
index dbda6440ad5..694fb8ce4b9 100644
--- a/tensorflow_datasets/scripts/cli/croissant.py
+++ b/tensorflow_datasets/scripts/cli/croissant.py
@@ -26,11 +26,9 @@
 ```
 """

-import argparse
 import dataclasses
 import functools
 import json
-import typing

 from etils import epath
 import mlcroissant as mlc
@@ -73,13 +71,7 @@ class CmdArgs(simple_parsing.helpers.FrozenSerializable):
       *(file_format.value for file_format in file_adapters.FileFormat),
       default=file_adapters.FileFormat.ARRAY_RECORD.value,
   )
-  # Need to manually parse comma-separated list of values, see:
-  # https://github.com/lebrice/SimpleParsing/issues/142.
-  record_sets: list[str] = simple_parsing.field(
-      default_factory=list,
-      type=lambda record_sets_str: record_sets_str.split(','),
-      nargs='?',
-  )
+  record_sets: list[str] = cli_utils.comma_separated_list_field()
   mapping: str | None = None
   download_dir: epath.Path | None = None
   publish_dir: epath.Path | None = None
@@ -112,23 +104,8 @@ def record_set_ids(self) -> list[str]:
         self.dataset.metadata
     )

-
-def register_subparser(parsers: argparse._SubParsersAction):
-  """Add subparser for `convert_format` command."""
-  orig_parser_class = parsers._parser_class  # pylint: disable=protected-access
-  try:
-    parsers._parser_class = simple_parsing.ArgumentParser  # pylint: disable=protected-access
-    parser = parsers.add_parser(
-        'build_croissant',
-        help='Prepares a croissant dataset',
-    )
-    parser = typing.cast(simple_parsing.ArgumentParser, parser)
-  finally:
-    parsers._parser_class = orig_parser_class  # pylint: disable=protected-access
-  parser.add_arguments(CmdArgs, dest='args')
-  parser.set_defaults(
-      subparser_fn=lambda args: prepare_croissant_builder(args.args)
-  )
+  def execute(self):
+    prepare_croissant_builder(self)


 def prepare_croissant_builder(args: CmdArgs) -> None:
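The next file wires the per-command dataclasses into a single CLI entry point via a subcommand union. A simplified sketch of the `simple_parsing.subparsers` pattern it relies on follows; `Hello` and `Cli` are hypothetical classes, not the real TFDS commands.

```python
import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True, kw_only=True)
class Hello:
  name: str = 'world'

  def execute(self) -> None:
    print(f'hello {self.name}')


@dataclasses.dataclass(frozen=True, kw_only=True)
class Cli:
  command: Hello | None = simple_parsing.subparsers(
      subcommands={'hello': Hello}, default=None
  )


parser = simple_parsing.ArgumentParser()
parser.add_arguments(Cli, dest='cli')
args = parser.parse_args(['hello', '--name', 'tfds'])
if args.cli.command:
  args.cli.command.execute()
```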
diff --git a/tensorflow_datasets/scripts/cli/main.py b/tensorflow_datasets/scripts/cli/main.py
index d20e6341866..7efef9da3ac 100644
--- a/tensorflow_datasets/scripts/cli/main.py
+++ b/tensorflow_datasets/scripts/cli/main.py
@@ -21,15 +21,14 @@
 See: https://www.tensorflow.org/datasets/cli
 """

-import argparse
+import dataclasses
 import logging as python_logging
-from typing import List

 from absl import app
 from absl import flags
 from absl import logging
-from absl.flags import argparse_flags
-
+from etils import eapp
+import simple_parsing
 import tensorflow_datasets.public_api as tfds

 # Import commands
@@ -37,35 +36,42 @@
 from tensorflow_datasets.scripts.cli import convert_format
 from tensorflow_datasets.scripts.cli import croissant
 from tensorflow_datasets.scripts.cli import new
-from tensorflow_datasets.scripts.utils import flag_utils

 FLAGS = flags.FLAGS


-def _parse_flags(argv: List[str]) -> argparse.Namespace:
-  """Command lines flag parsing."""
-  argv = flag_utils.normalize_flags(argv)  # See b/174043007 for context.
-
-  parser = argparse_flags.ArgumentParser(
-      description='Tensorflow Datasets CLI tool',
-      allow_abbrev=False,
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """TFDS CLI."""
+
+  command: (
+      build.CmdArgs | convert_format.CmdArgs | croissant.CmdArgs | new.CmdArgs
+  ) = simple_parsing.subparsers(
+      subcommands={
+          'build': build.CmdArgs,
+          'convert_format': convert_format.CmdArgs,
+          'build_croissant': croissant.CmdArgs,
+          'new': new.CmdArgs,
+      },
+      default=None,
   )
-  parser.add_argument(
-      '--version',
+  # pyformat: disable
+  version: bool = simple_parsing.field(
+      # pyformat: enable
       action='version',
       version='TensorFlow Datasets: ' + tfds.__version__,
+      help='Print version.',
   )
-  parser.set_defaults(subparser_fn=lambda _: parser.print_help())
-
-  # Register sub-commands
-  subparser = parser.add_subparsers(title='command')
-  build.register_subparser(subparser)
-  new.register_subparser(subparser)
-  convert_format.register_subparser(subparser)
-  croissant.register_subparser(subparser)
-  return parser.parse_args(argv[1:])


-def main(args: argparse.Namespace) -> None:
+_parse_flags = eapp.make_flags_parser(
+    CmdArgs,
+    description='Tensorflow Datasets CLI tool',
+    allow_abbrev=False,
+)
+
+
+def main(args: CmdArgs) -> None:

   # From the CLI, all datasets are visible
   tfds.core.visibility.set_availables([
@@ -97,7 +103,10 @@ def main(args: argparse.Namespace) -> None:
     python_handler.setStream(new_stream)

   # Launch the subcommand defined in the subparser (or default to print help)
-  args.subparser_fn(args)
+  if args.command:
+    args.command.execute()
+  else:
+    _parse_flags(['', '--help'])


 def launch_cli() -> None:
diff --git a/tensorflow_datasets/scripts/cli/main_test.py b/tensorflow_datasets/scripts/cli/main_test.py
index 18e1c4e0d52..13e6d922623 100644
--- a/tensorflow_datasets/scripts/cli/main_test.py
+++ b/tensorflow_datasets/scripts/cli/main_test.py
@@ -28,4 +28,4 @@ def _check_exit(status=0, message=None):
   # Argparse call `sys.exit(0)` when `--version` is passed.
   with mock.patch('sys.exit', _check_exit):
     version_flag = '--version'
-    main.main(main._parse_flags(['', version_flag]))
+    main.main(main._parse_flags([version_flag]))
diff --git a/tensorflow_datasets/scripts/cli/new.py b/tensorflow_datasets/scripts/cli/new.py
index 908cbecc2be..ad1b09908ba 100644
--- a/tensorflow_datasets/scripts/cli/new.py
+++ b/tensorflow_datasets/scripts/cli/new.py
@@ -15,13 +15,14 @@
 """`tfds new` command."""

-import argparse
+import dataclasses
 import os
 import pathlib
 import subprocess
 import textwrap
 from typing import Optional

+import simple_parsing
 from tensorflow_datasets.core import constants
 from tensorflow_datasets.core import dataset_metadata
 from tensorflow_datasets.core import naming
@@ -30,43 +31,34 @@
 from tensorflow_datasets.scripts.cli import cli_utils as utils


-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `new` command."""
-  new_parser = parsers.add_parser(
-      'new', help='Creates a new dataset directory from the template.'
-  )
-  new_parser.add_argument(
-      'dataset_name',  # Positional argument
-      type=str,
-      help='Name of the dataset to be created (in snake_case)',
-  )
-  new_parser.add_argument(
-      '--data_format',  # Optional argument
-      type=str,
-      default=builder_templates.STANDARD,
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """Creates a new dataset directory from the template.
+
+  Attributes:
+    dataset_name: Name of the dataset to be created (in snake_case).
+    data_format: Optional format of the input data, which is used to generate a
+      format-specific template.
+    dir: Path where the dataset directory will be created. Defaults to current
+      directory.
+  """
+
+  dataset_name: str = simple_parsing.field(positional=True)
+  data_format: str = simple_parsing.field(
       choices=[
           builder_templates.STANDARD,
           builder_templates.CONLL,
           builder_templates.CONLLU,
       ],
-      help=(
-          'Optional format of the input data, which is used to generate a '
-          'format-specific template.'
-      ),
-  )
-  new_parser.add_argument(
-      '--dir',
-      type=pathlib.Path,
-      default=pathlib.Path.cwd(),
-      help=(
-          'Path where the dataset directory will be created. '
-          'Defaults to current directory.'
-      ),
+      default=builder_templates.STANDARD,
   )
-  new_parser.set_defaults(subparser_fn=_create_dataset_files)
+  dir: pathlib.Path = simple_parsing.field(default=pathlib.Path.cwd())
+
+  def execute(self):
+    _create_dataset_files(self)


-def _create_dataset_files(args: argparse.Namespace) -> None:
+def _create_dataset_files(args: CmdArgs) -> None:
   """Creates the dataset directory. Executed by `tfds new <name>`."""
   if not naming.is_valid_dataset_and_class_name(args.dataset_name):
     raise ValueError(
diff --git a/tensorflow_datasets/scripts/cli/new_test.py b/tensorflow_datasets/scripts/cli/new_test.py
index b612c4b90a4..c76bedece0b 100644
--- a/tensorflow_datasets/scripts/cli/new_test.py
+++ b/tensorflow_datasets/scripts/cli/new_test.py
@@ -32,7 +32,10 @@ def test_new_without_args(capsys):
     _run_cli('new')
   captured = capsys.readouterr()
-  assert 'the following arguments are required: dataset_name' in captured.err
+  assert (
+      'the following arguments are required: args.command.dataset_name'
+      in captured.err
+  )


 def test_new_invalid_name():
diff --git a/tensorflow_datasets/scripts/convert_format.py b/tensorflow_datasets/scripts/convert_format.py
index 314bbadf0bd..246d95378ae 100644
--- a/tensorflow_datasets/scripts/convert_format.py
+++ b/tensorflow_datasets/scripts/convert_format.py
@@ -33,95 +33,24 @@
 """

 from absl import app
-from absl import flags
-from tensorflow_datasets.core import file_adapters
+from etils import eapp
+from tensorflow_datasets.scripts.cli import convert_format as convert_format_cli
 from tensorflow_datasets.scripts.cli import convert_format_utils


-_ROOT_DATA_DIR = flags.DEFINE_string(
-    'root_data_dir',
-    required=False,
-    help=(
-        'Root data dir that contains all datasets. All datasets and all their'
-        ' configs and versions that are in this folder will be converted.'
-    ),
-    default=None,
-)
-_DATASET_DIR = flags.DEFINE_string(
-    'dataset_dir',
-    required=False,
-    help=(
-        'Path where the dataset to be converted is located. Converts all'
-        ' configs and versions in this folder.'
-    ),
-    default=None,
-)
-_DATASET_VERSION_DIR = flags.DEFINE_list(
-    'dataset_version_dir',
-    required=False,
-    help=(
-        'Path where the dataset to be converted is located. Should include'
-        ' config and version. Can also be a comma-separated list of paths. If'
-        ' multiple paths are specified, `--out_dir` should not be specified,'
-        ' since each dataset will be converted in the same directory as the'
-        ' input dataset.'
-    ),
-    default=None,
-)
-
-_OUT_FILE_FORMAT = flags.DEFINE_enum_class(
-    'out_file_format',
-    enum_class=file_adapters.FileFormat,
-    required=True,
-    help='File format to convert the dataset to.',
-    default=None,
-)
-
-_OUT_DIR = flags.DEFINE_string(
-    'out_dir',
-    required=False,
-    help=(
-        'Path where the converted dataset will be stored. Should include the'
-        ' config and version, e.g. `/data/dataset_name/config/1.2.3`. If not'
-        ' specified, the converted shards will be stored in the same directory'
-        ' as the input dataset.'
-    ),
-    default=None,
-)
-
-_USE_BEAM = flags.DEFINE_bool(
-    'use_beam',
-    default=False,
-    help='Whether to use beam to convert the dataset.',
-)
-
-_NUM_WORKERS = flags.DEFINE_integer(
-    'num_workers',
-    default=8,
-    help='Number of workers to use if `use_beam` is `False`.',
-)
-
-_OVERWRITE = flags.DEFINE_bool(
-    'overwrite',
-    default=False,
-    help='Whether to overwrite the output folder.',
-)
-
-
-def main(_):
+def main(args: convert_format_cli.CmdArgs):
   convert_format_utils.convert_dataset(
-      root_data_dir=_ROOT_DATA_DIR.value,
-      dataset_dir=_DATASET_DIR.value,
-      dataset_version_dir=_DATASET_VERSION_DIR.value,
-      out_file_format=_OUT_FILE_FORMAT.value,
-      out_dir=_OUT_DIR.value,
-      use_beam=_USE_BEAM.value,
-      overwrite=_OVERWRITE.value,
-      num_workers=_NUM_WORKERS.value,
+      out_dir=args.out_dir,
+      out_file_format=args.out_file_format,
+      root_data_dir=args.root_data_dir,
+      dataset_dir=args.dataset_dir,
+      dataset_version_dir=args.dataset_version_dir,
+      overwrite=args.overwrite,
+      use_beam=args.use_beam,
+      num_workers=args.num_workers,
   )


 if __name__ == '__main__':
-  app.run(main)
+  app.run(main, flags_parser=eapp.make_flags_parser(convert_format_cli.CmdArgs))
diff --git a/tensorflow_datasets/scripts/download_and_prepare.py b/tensorflow_datasets/scripts/download_and_prepare.py
index cb563689105..241da2b536c 100644
--- a/tensorflow_datasets/scripts/download_and_prepare.py
+++ b/tensorflow_datasets/scripts/download_and_prepare.py
@@ -15,43 +15,56 @@
 r"""Wrapper around `tfds build`."""

-import argparse
+import dataclasses
 from typing import List

 from absl import app
-from absl import flags
 from absl import logging
-
+import simple_parsing
 from tensorflow_datasets.scripts.cli import main as main_cli

-module_import = flags.DEFINE_string('module_import', None, '`--imports` flag.')
-dataset = flags.DEFINE_string('dataset', None, 'singleton `--datasets` flag.')
-builder_config_id = flags.DEFINE_integer(
-    'builder_config_id', None, '`--config_idx` flag'
-)

+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """CLI arguments for downloading and preparing datasets.
+
+  Attributes:
+    module_import: `--imports` flag.
+    dataset: Singleton `--datasets` flag.
+    builder_config_id: `--config_idx` flag.
+ """ + module_import: str | None = None + dataset: str | None = None + builder_config_id: int | None = None -def _parse_flags(argv: List[str]) -> argparse.Namespace: +def _parse_flags(argv: List[str]) -> main_cli.CmdArgs: """Command lines flag parsing.""" - return main_cli._parse_flags([argv[0], 'build'] + argv[1:]) # pylint: disable=protected-access + parser = simple_parsing.ArgumentParser() + parser.add_arguments(CmdArgs, dest='args') + namespace, build_argv = parser.parse_known_args(argv[1:]) + args = namespace.args + + # Construct CLI arguments for build command + build_argv = [argv[0], 'build'] + build_argv + if args.module_import: + build_argv += ['--imports', args.module_import] + if args.dataset: + build_argv += ['--datasets', args.dataset] + if args.builder_config_id is not None: + build_argv += ['--config_idx', args.builder_config_id] + return main_cli._parse_flags(build_argv) # pylint: disable=protected-access _display_warning = True -def main(args: argparse.Namespace) -> None: +def main(args: main_cli.CmdArgs) -> None: if _display_warning: logging.warning( '***`tfds build` should be used instead of `download_and_prepare`.***' ) - if module_import.value: - args.imports = module_import.value - if dataset.value: - args.datasets = [dataset.value] - if builder_config_id.value is not None: - args.config_idx = builder_config_id.value main_cli.main(args)