|
| 1 | +"""Tapify module, which can initialize a class or run a function by parsing arguments from the command line.""" |
| 2 | +from inspect import signature, Parameter |
| 3 | +from typing import Any, Callable, List, Optional, TypeVar, Union |
| 4 | + |
| 5 | +from docstring_parser import parse |
| 6 | + |
| 7 | +from tap import Tap |
| 8 | + |
| 9 | +InputType = TypeVar('InputType') |
| 10 | +OutputType = TypeVar('OutputType') |
| 11 | + |
| 12 | + |
| 13 | +def tapify(class_or_function: Union[Callable[[InputType], OutputType], OutputType], |
| 14 | + args: Optional[List[str]] = None, |
| 15 | + known_only: bool = False, |
| 16 | + **func_kwargs) -> OutputType: |
| 17 | + """Tapify initializes a class or runs a function by parsing arguments from the command line. |
| 18 | +
|
| 19 | + :param class_or_function: The class or function to run with the provided arguments. |
| 20 | + :param args: Arguments to parse. If None, arguments are parsed from the command line. |
| 21 | + :param known_only: If true, ignores extra arguments and only parses known arguments. |
| 22 | + :param func_kwargs: Additional keyword arguments for the function. These act as default values when |
| 23 | + parsing the command line arguments and overwrite the function defaults but |
| 24 | + are overwritten by the parsed command line arguments. |
| 25 | + """ |
| 26 | + # Get signature from class or function |
| 27 | + sig = signature(class_or_function) |
| 28 | + |
| 29 | + # Parse class or function docstring in one line |
| 30 | + if isinstance(class_or_function, type) and class_or_function.__init__.__doc__ is not None: |
| 31 | + doc = class_or_function.__init__.__doc__ |
| 32 | + else: |
| 33 | + doc = class_or_function.__doc__ |
| 34 | + |
| 35 | + # Parse docstring |
| 36 | + docstring = parse(doc) |
| 37 | + |
| 38 | + # Get the description of each argument in the class init or function |
| 39 | + param_to_description = {param.arg_name: param.description for param in docstring.params} |
| 40 | + |
| 41 | + # Create a Tap object |
| 42 | + tap = Tap(description='\n'.join(filter(None, (docstring.short_description, docstring.long_description)))) |
| 43 | + |
| 44 | + # Add arguments of class init or function to the Tap object |
| 45 | + for param_name, param in sig.parameters.items(): |
| 46 | + tap_kwargs = {} |
| 47 | + |
| 48 | + # Get type of the argument |
| 49 | + if param.annotation != Parameter.empty: |
| 50 | + # Any type defaults to str (needed for dataclasses where all non-default attributes must have a type) |
| 51 | + if param.annotation is Any: |
| 52 | + tap._annotations[param.name] = str |
| 53 | + # Otherwise, get the type of the argument |
| 54 | + else: |
| 55 | + tap._annotations[param.name] = param.annotation |
| 56 | + |
| 57 | + # Get the default or required of the argument |
| 58 | + if param.name in func_kwargs: |
| 59 | + tap_kwargs['default'] = func_kwargs[param.name] |
| 60 | + del func_kwargs[param.name] |
| 61 | + elif param.default != Parameter.empty: |
| 62 | + tap_kwargs['default'] = param.default |
| 63 | + else: |
| 64 | + tap_kwargs['required'] = True |
| 65 | + |
| 66 | + # Get the help string of the argument |
| 67 | + if param.name in param_to_description: |
| 68 | + tap.class_variables[param.name] = {'comment': param_to_description[param.name]} |
| 69 | + |
| 70 | + # Add the argument to the Tap object |
| 71 | + tap._add_argument(f'--{param_name}', **tap_kwargs) |
| 72 | + |
| 73 | + # If any func_kwargs remain, they are not used in the function, so raise an error |
| 74 | + if func_kwargs and not known_only: |
| 75 | + raise ValueError(f'Unknown keyword arguments: {func_kwargs}') |
| 76 | + |
| 77 | + # Parse command line arguments |
| 78 | + args = tap.parse_args( |
| 79 | + args=args, |
| 80 | + known_only=known_only |
| 81 | + ) |
| 82 | + |
| 83 | + # Initialize the class or run the function with the parsed arguments |
| 84 | + return class_or_function(**args.as_dict()) |
0 commit comments