Skip to content

Commit b51d06c

Browse files
authored
Merge pull request #47 from swansonk14/complex-types
Complex types
2 parents 0dc69b3 + ef08776 commit b51d06c

File tree

4 files changed

+208
-54
lines changed

4 files changed

+208
-54
lines changed

tap/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__all__ = ['__version__']
22

33
# major, minor, patch
4-
version_info = 1, 6, 1
4+
version_info = 1, 6, 2
55

66
# Nice string for the version
77
__version__ = '.'.join(map(str, version_info))

tap/tap.py

+36-37
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
from warnings import warn
99
from types import MethodType
1010
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
11-
from typing_inspect import is_literal_type, get_args, get_origin, is_union_type
11+
from typing_inspect import is_literal_type, get_args
1212

1313
from tap.utils import (
1414
get_class_variables,
1515
get_argument_name,
1616
get_git_root,
1717
get_dest,
1818
get_git_url,
19+
get_origin,
1920
has_git,
2021
has_uncommitted_changes,
2122
is_option_arg,
@@ -32,16 +33,10 @@
3233

3334
# Constants
3435
EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple()
36+
BOXED_COLLECTION_TYPES = {List, list, Set, set, Tuple, tuple}
37+
OPTIONAL_TYPES = {Optional, Union}
38+
BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES
3539

36-
SUPPORTED_DEFAULT_BASE_TYPES = {str, int, float, bool}
37-
SUPPORTED_DEFAULT_OPTIONAL_TYPES = {Optional, Optional[str], Optional[int], Optional[float], Optional[bool]}
38-
SUPPORTED_DEFAULT_LIST_TYPES = {List, List[str], List[int], List[float], List[bool]}
39-
SUPPORTED_DEFAULT_SET_TYPES = {Set, Set[str], Set[int], Set[float], Set[bool]}
40-
SUPPORTED_DEFAULT_COLLECTION_TYPES = SUPPORTED_DEFAULT_LIST_TYPES | SUPPORTED_DEFAULT_SET_TYPES | {Tuple}
41-
SUPPORTED_DEFAULT_BOXED_TYPES = SUPPORTED_DEFAULT_OPTIONAL_TYPES | SUPPORTED_DEFAULT_COLLECTION_TYPES
42-
SUPPORTED_DEFAULT_TYPES = set.union(SUPPORTED_DEFAULT_BASE_TYPES,
43-
SUPPORTED_DEFAULT_OPTIONAL_TYPES,
44-
SUPPORTED_DEFAULT_COLLECTION_TYPES)
4540

4641
TapType = TypeVar('TapType', bound='Tap')
4742

@@ -125,6 +120,9 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
125120
:param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
126121
:param kwargs: Keyword arguments.
127122
"""
123+
# Set explicit bool
124+
explicit_bool = self._explicit_bool
125+
128126
# Get variable name
129127
variable = get_argument_name(*name_or_flags)
130128

@@ -168,6 +166,21 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
168166

169167
# If type is not explicitly provided, set it if it's one of our supported default types
170168
if 'type' not in kwargs:
169+
170+
# Unbox Optional[type] and set var_type = type
171+
if get_origin(var_type) in OPTIONAL_TYPES:
172+
var_args = get_args(var_type)
173+
174+
if len(var_args) > 0:
175+
var_type = get_args(var_type)[0]
176+
177+
# If var_type is tuple as in Python 3.6, change to a typing type
178+
# (e.g., (typing.List, <class 'bool'>) ==> typing.List[bool])
179+
if isinstance(var_type, tuple):
180+
var_type = var_type[0][var_type[1:]]
181+
182+
explicit_bool = True
183+
171184
# First check whether it is a literal type or a boxed literal type
172185
if is_literal_type(var_type):
173186
var_type, kwargs['choices'] = get_literals(var_type, variable)
@@ -195,27 +208,10 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
195208
kwargs['nargs'] = len(types)
196209

197210
var_type = TupleTypeEnforcer(types=types, loop=loop)
198-
# To identify an Optional type, check if it's a union of a None and something else
199-
elif (
200-
is_union_type(var_type)
201-
and len(get_args(var_type)) == 2
202-
and isinstance(None, get_args(var_type)[1])
203-
and is_literal_type(get_args(var_type)[0])
204-
):
205-
var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable)
206-
elif var_type not in SUPPORTED_DEFAULT_TYPES:
207-
is_required = kwargs.get('required', False)
208-
arg_params = 'required=True' if is_required else f'default={getattr(self, variable)}'
209-
raise ValueError(
210-
f'Variable "{variable}" has type "{var_type}" which is not supported by default.\n'
211-
f'Please explicitly add the argument to the parser by writing:\n\n'
212-
f'def configure(self) -> None:\n'
213-
f' self.add_argument("--{variable}", type=func, {arg_params})\n\n'
214-
f'where "func" maps from str to {var_type}.')
215-
216-
if var_type in SUPPORTED_DEFAULT_BOXED_TYPES:
211+
212+
if get_origin(var_type) in BOXED_TYPES:
217213
# If List or Set type, set nargs
218-
if (var_type in SUPPORTED_DEFAULT_COLLECTION_TYPES
214+
if (get_origin(var_type) in BOXED_COLLECTION_TYPES
219215
and kwargs.get('action') not in {'append', 'append_const'}):
220216
kwargs['nargs'] = kwargs.get('nargs', '*')
221217

@@ -228,13 +224,12 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
228224
else:
229225
var_type = arg_types[0]
230226

231-
# Handle the cases of Optional[bool], List[bool], Set[bool]
227+
# Handle the cases of List[bool], Set[bool], Tuple[bool]
232228
if var_type == bool:
233229
var_type = boolean_type
234-
235230
# If bool then set action, otherwise set type
236231
if var_type == bool:
237-
if self._explicit_bool:
232+
if explicit_bool:
238233
kwargs['type'] = boolean_type
239234
kwargs['choices'] = [True, False] # this makes the help message more helpful
240235
else:
@@ -404,10 +399,14 @@ def parse_args(self: TapType,
404399
if type(value) == list:
405400
var_type = get_origin(self._annotations[variable])
406401

407-
# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9
408-
# https://github.com/ilevkivskyi/typing_inspect/issues/64
409-
# https://github.com/ilevkivskyi/typing_inspect/issues/65
410-
var_type = var_type if var_type is not None else self._annotations[variable]
402+
# Unpack nested boxed types such as Optional[List[int]]
403+
if var_type is Union:
404+
var_type = get_origin(get_args(self._annotations[variable])[0])
405+
406+
# If var_type is tuple as in Python 3.6, change to a typing type
407+
# (e.g., (typing.Tuple, <class 'bool'>) ==> typing.Tuple)
408+
if isinstance(var_type, tuple):
409+
var_type = var_type[0]
411410

412411
if var_type in (Set, set):
413412
value = set(value)

tap/utils.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
Union,
2525
)
2626
from typing_extensions import Literal
27-
from typing_inspect import get_args
27+
from typing_inspect import get_args, get_origin as typing_inspect_get_origin
2828

2929

3030
NO_CHANGES_STATUS = """nothing to commit, working tree clean"""
@@ -467,3 +467,16 @@ def enforce_reproducibility(saved_reproducibility_data: Optional[Dict[str, str]]
467467
if current_reproducibility_data['git_has_uncommitted_changes']:
468468
raise ValueError(f'{no_reproducibility_message}: Uncommitted changes '
469469
f'in current args.')
470+
471+
472+
# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8 and 3.9
473+
# https://github.com/ilevkivskyi/typing_inspect/issues/64
474+
# https://github.com/ilevkivskyi/typing_inspect/issues/65
475+
def get_origin(tp: Any) -> Any:
476+
"""Same as typing_inspect.get_origin but fixes unparameterized generic types like Set."""
477+
origin = typing_inspect_get_origin(tp)
478+
479+
if origin is None:
480+
origin = tp
481+
482+
return origin

tests/test_integration.py

+157-15
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
from copy import deepcopy
22
import os
3+
from pathlib import Path
34
import sys
45
from tempfile import TemporaryDirectory
5-
from typing import Any, List, Optional, Set, Tuple
6+
from typing import Any, Iterable, List, Optional, Set, Tuple
67
from typing_extensions import Literal
78
import unittest
89
from unittest import TestCase
910

1011
from tap import Tap
1112

1213

14+
def stringify(arg_list: Iterable[Any]) -> List[str]:
15+
"""Converts an iterable of arguments of any type to a list of strings.
16+
17+
:param arg_list: An iterable of arguments of any type.
18+
:return: A list of the arguments as strings.
19+
"""
20+
return [str(arg) for arg in arg_list]
21+
22+
1323
class EdgeCaseTests(TestCase):
1424
def test_empty(self) -> None:
1525
class EmptyTap(Tap):
@@ -112,17 +122,144 @@ def test_both_assigned_okay(self):
112122
self.assertEqual(args.arg_list_str_required, ['hi', 'there'])
113123

114124

115-
class CrashesOnUnsupportedTypesTests(TestCase):
125+
# TODO: need to implement list[str] etc.
126+
# class ParameterizedStandardCollectionTap(Tap):
127+
# arg_list_str: list[str]
128+
# arg_list_int: list[int]
129+
# arg_list_int_default: list[int] = [1, 2, 5]
130+
# arg_set_float: set[float]
131+
# arg_set_str_default: set[str] = ['one', 'two', 'five']
132+
# arg_tuple_int: tuple[int, ...]
133+
# arg_tuple_float_default: tuple[float, float, float] = (1.0, 2.0, 5.0)
134+
# arg_tuple_str_override: tuple[str, str] = ('hi', 'there')
135+
# arg_optional_list_int: Optional[list[int]] = None
136+
137+
138+
# class ParameterizedStandardCollectionTests(TestCase):
139+
# @unittest.skipIf(sys.version_info < (3, 9), 'Parameterized standard collections (e.g., list[int]) introduced in Python 3.9')
140+
# def test_parameterized_standard_collection(self):
141+
# arg_list_str = ['a', 'b', 'pi']
142+
# arg_list_int = [-2, -5, 10]
143+
# arg_set_float = {3.54, 2.235}
144+
# arg_tuple_int = (-4, 5, 9, 103)
145+
# arg_tuple_str_override = ('why', 'so', 'many', 'tests?')
146+
# arg_optional_list_int = [5, 4, 3]
147+
148+
# args = ParameterizedStandardCollectionTap().parse_args([
149+
# '--arg_list_str', *arg_list_str,
150+
# '--arg_list_int', *[str(var) for var in arg_list_int],
151+
# '--arg_set_float', *[str(var) for var in arg_set_float],
152+
# '--arg_tuple_int', *[str(var) for var in arg_tuple_int],
153+
# '--arg_tuple_str_override', *arg_tuple_str_override,
154+
# '--arg_optional_list_int', *[str(var) for var in arg_optional_list_int]
155+
# ])
156+
157+
# self.assertEqual(args.arg_list_str, arg_list_str)
158+
# self.assertEqual(args.arg_list_int, arg_list_int)
159+
# self.assertEqual(args.arg_list_int_default, ParameterizedStandardCollectionTap.arg_list_int_default)
160+
# self.assertEqual(args.arg_set_float, arg_set_float)
161+
# self.assertEqual(args.arg_set_str_default, ParameterizedStandardCollectionTap.arg_set_str_default)
162+
# self.assertEqual(args.arg_tuple_int, arg_tuple_int)
163+
# self.assertEqual(args.arg_tuple_float_default, ParameterizedStandardCollectionTap.arg_tuple_float_default)
164+
# self.assertEqual(args.arg_tuple_str_override, arg_tuple_str_override)
165+
# self.assertEqual(args.arg_optional_list_int, arg_optional_list_int)
166+
167+
168+
class NestedOptionalTypesTap(Tap):
169+
list_bool: Optional[List[bool]]
170+
list_int: Optional[List[int]]
171+
list_str: Optional[List[str]]
172+
set_bool: Optional[Set[bool]]
173+
set_int: Optional[Set[int]]
174+
set_str: Optional[Set[str]]
175+
tuple_bool: Optional[Tuple[bool]]
176+
tuple_int: Optional[Tuple[int]]
177+
tuple_str: Optional[Tuple[str]]
178+
tuple_pair: Optional[Tuple[bool, str, int]]
179+
tuple_arbitrary_len_bool: Optional[Tuple[bool, ...]]
180+
tuple_arbitrary_len_int: Optional[Tuple[int, ...]]
181+
tuple_arbitrary_len_str: Optional[Tuple[str, ...]]
182+
183+
184+
class NestedOptionalTypeTests(TestCase):
185+
186+
def test_nested_optional_types(self):
187+
list_bool = [True, False]
188+
list_int = [0, 1, 2]
189+
list_str = ['a', 'bee', 'cd', 'ee']
190+
set_bool = {True, False, True}
191+
set_int = {0, 1}
192+
set_str = {'a', 'bee', 'cd'}
193+
tuple_bool = (False,)
194+
tuple_int = (0,)
195+
tuple_str = ('a',)
196+
tuple_pair = (False, 'a', 1)
197+
tuple_arbitrary_len_bool = (True, False, False)
198+
tuple_arbitrary_len_int = (1, 2, 3, 4)
199+
tuple_arbitrary_len_str = ('a', 'b')
200+
201+
args = NestedOptionalTypesTap().parse_args([
202+
'--list_bool', *stringify(list_bool),
203+
'--list_int', *stringify(list_int),
204+
'--list_str', *stringify(list_str),
205+
'--set_bool', *stringify(set_bool),
206+
'--set_int', *stringify(set_int),
207+
'--set_str', *stringify(set_str),
208+
'--tuple_bool', *stringify(tuple_bool),
209+
'--tuple_int', *stringify(tuple_int),
210+
'--tuple_str', *stringify(tuple_str),
211+
'--tuple_pair', *stringify(tuple_pair),
212+
'--tuple_arbitrary_len_bool', *stringify(tuple_arbitrary_len_bool),
213+
'--tuple_arbitrary_len_int', *stringify(tuple_arbitrary_len_int),
214+
'--tuple_arbitrary_len_str', *stringify(tuple_arbitrary_len_str),
215+
])
116216

117-
def test_crashes_on_unsupported(self):
118-
# From PiDelport: https://github.com/swansonk14/typed-argument-parser/issues/27
119-
from pathlib import Path
217+
self.assertEqual(args.list_bool, list_bool)
218+
self.assertEqual(args.list_int, list_int)
219+
self.assertEqual(args.list_str, list_str)
220+
221+
self.assertEqual(args.set_bool, set_bool)
222+
self.assertEqual(args.set_int, set_int)
223+
self.assertEqual(args.set_str, set_str)
224+
225+
self.assertEqual(args.tuple_bool, tuple_bool)
226+
self.assertEqual(args.tuple_int, tuple_int)
227+
self.assertEqual(args.tuple_str, tuple_str)
228+
self.assertEqual(args.tuple_pair, tuple_pair)
229+
self.assertEqual(args.tuple_arbitrary_len_bool, tuple_arbitrary_len_bool)
230+
self.assertEqual(args.tuple_arbitrary_len_int, tuple_arbitrary_len_int)
231+
self.assertEqual(args.tuple_arbitrary_len_str, tuple_arbitrary_len_str)
232+
233+
234+
class ComplexTypeTap(Tap):
235+
path: Path
236+
optional_path: Optional[Path]
237+
list_path: List[Path]
238+
set_path: Set[Path]
239+
tuple_path: Tuple[Path, Path]
240+
241+
242+
class ComplexTypeTests(TestCase):
243+
def test_complex_types(self):
244+
path = Path('/path/to/file.txt')
245+
optional_path = Path('/path/to/optional/file.txt')
246+
list_path = [Path('/path/to/list/file1.txt'), Path('/path/to/list/file2.txt')]
247+
set_path = {Path('/path/to/set/file1.txt'), Path('/path/to/set/file2.txt')}
248+
tuple_path = (Path('/path/to/tuple/file1.txt'), Path('/path/to/tuple/file2.txt'))
249+
250+
args = ComplexTypeTap().parse_args([
251+
'--path', str(path),
252+
'--optional_path', str(optional_path),
253+
'--list_path', *[str(path) for path in list_path],
254+
'--set_path', *[str(path) for path in set_path],
255+
'--tuple_path', *[str(path) for path in tuple_path]
256+
])
120257

121-
class CrashingArgumentParser(Tap):
122-
some_path: Path = 'some_path'
123-
124-
with self.assertRaises(ValueError):
125-
CrashingArgumentParser().parse_args([])
258+
self.assertEqual(args.path, path)
259+
self.assertEqual(args.optional_path, optional_path)
260+
self.assertEqual(args.list_path, list_path)
261+
self.assertEqual(args.set_path, set_path)
262+
self.assertEqual(args.tuple_path, tuple_path)
126263

127264

128265
class Person:
@@ -312,7 +449,6 @@ def test_set_default_args(self) -> None:
312449
'--arg_list_bool', *arg_list_bool,
313450
'--arg_list_str_empty', *arg_list_str_empty,
314451
'--arg_list_literal', *arg_list_literal,
315-
316452
'--arg_set', *arg_set,
317453
'--arg_set_str', *arg_set_str,
318454
'--arg_set_int', *arg_set_int,
@@ -496,26 +632,32 @@ def configure(self) -> None:
496632
def test_complex_type(self) -> None:
497633
class AddArgumentComplexTypeTap(IntegrationDefaultTap):
498634
arg_person: Person = Person('tap')
499-
# arg_person_required: Person # TODO
635+
arg_person_required: Person
500636
arg_person_untyped = Person('tap untyped')
501637

502-
# TODO: assert a crash if any complex types are not explicitly added in add_argument
503638
def configure(self) -> None:
504639
self.add_argument('--arg_person', type=Person)
505-
# self.add_argument('--arg_person_required', type=Person) # TODO
640+
self.add_argument('--arg_person_required', type=Person)
506641
self.add_argument('--arg_person_untyped', type=Person)
507642

508-
args = AddArgumentComplexTypeTap().parse_args([])
643+
arg_person_required = Person("hello, it's me")
644+
645+
args = AddArgumentComplexTypeTap().parse_args([
646+
'--arg_person_required', arg_person_required.name,
647+
])
509648
self.assertEqual(args.arg_person, Person('tap'))
649+
self.assertEqual(args.arg_person_required, arg_person_required)
510650
self.assertEqual(args.arg_person_untyped, Person('tap untyped'))
511651

512652
arg_person = Person('hi there')
513653
arg_person_untyped = Person('heyyyy')
514654
args = AddArgumentComplexTypeTap().parse_args([
515655
'--arg_person', arg_person.name,
656+
'--arg_person_required', arg_person_required.name,
516657
'--arg_person_untyped', arg_person_untyped.name
517658
])
518659
self.assertEqual(args.arg_person, arg_person)
660+
self.assertEqual(args.arg_person_required, arg_person_required)
519661
self.assertEqual(args.arg_person_untyped, arg_person_untyped)
520662

521663
def test_repeat_default(self) -> None:

0 commit comments

Comments
 (0)