Skip to content

Commit 5198e0b

Browse files
committed
Boxed types in Option now supported as well as external types such as Path
1 parent 4b89fb9 commit 5198e0b

File tree

3 files changed

+34
-34
lines changed

3 files changed

+34
-34
lines changed

tap/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__all__ = ['__version__']
22

33
# major, minor, patch
4-
version_info = 1, 6, 1
4+
version_info = 1, 6, 2
55

66
# Nice string for the version
77
__version__ = '.'.join(map(str, version_info))

tap/tap.py

+10-17
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from warnings import warn
99
from types import MethodType
1010
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
11-
from typing_inspect import is_literal_type, get_args, is_union_type
11+
from typing_inspect import is_literal_type, get_args
1212

1313
from tap.utils import (
1414
get_class_variables,
@@ -166,6 +166,15 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
166166

167167
# If type is not explicitly provided, set it if it's one of our supported default types
168168
if 'type' not in kwargs:
169+
170+
# Unbox Optional[type] and set var_type = type
171+
if get_origin(var_type) in OPTIONAL_TYPES:
172+
var_args = get_args(var_type)
173+
174+
if len(var_args) > 0:
175+
var_type = get_args(var_type)[0]
176+
explicit_bool = True
177+
169178
# First check whether it is a literal type or a boxed literal type
170179
if is_literal_type(var_type):
171180
var_type, kwargs['choices'] = get_literals(var_type, variable)
@@ -193,22 +202,6 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
193202
kwargs['nargs'] = len(types)
194203

195204
var_type = TupleTypeEnforcer(types=types, loop=loop)
196-
# To identify an Optional type, check if it's a union of a None and something else
197-
elif (
198-
is_union_type(var_type)
199-
and len(get_args(var_type)) == 2
200-
and isinstance(None, get_args(var_type)[1])
201-
and is_literal_type(get_args(var_type)[0])
202-
):
203-
var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable)
204-
205-
# Unbox Optional[type] and set var_type = type
206-
if get_origin(var_type) in OPTIONAL_TYPES:
207-
var_args = get_args(var_type)
208-
209-
if len(var_args) > 0:
210-
var_type = get_args(var_type)[0]
211-
explicit_bool = True
212205

213206
if get_origin(var_type) in BOXED_TYPES:
214207
# If List or Set type, set nargs

tests/test_integration.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,10 @@ class NestedOptionalTypesTap(Tap):
175175
tuple_bool: Optional[Tuple[bool]]
176176
tuple_int: Optional[Tuple[int]]
177177
tuple_str: Optional[Tuple[str]]
178-
# tuple_pair: Optional[Tuple[bool, str, int]]
179-
# tuple_arbitrary_len_bool: Optional[Tuple[bool, ...]]
180-
# tuple_arbitrary_len_int: Optional[Tuple[int, ...]]
181-
# tuple_arbitrary_len_str: Optional[Tuple[str, ...]]
178+
tuple_pair: Optional[Tuple[bool, str, int]]
179+
tuple_arbitrary_len_bool: Optional[Tuple[bool, ...]]
180+
tuple_arbitrary_len_int: Optional[Tuple[int, ...]]
181+
tuple_arbitrary_len_str: Optional[Tuple[str, ...]]
182182

183183

184184
class NestedOptionalTypeTests(TestCase):
@@ -208,10 +208,10 @@ def test_nested_optional_types(self):
208208
'--tuple_bool', *stringify(tuple_bool),
209209
'--tuple_int', *stringify(tuple_int),
210210
'--tuple_str', *stringify(tuple_str),
211-
# '--tuple_pair', *stringify(tuple_pair),
212-
# '--tuple_arbitrary_len_bool', *stringify(tuple_arbitrary_len_bool),
213-
# '--tuple_arbitrary_len_int', *stringify(tuple_arbitrary_len_int),
214-
# '--tuple_arbitrary_len_str', *stringify(tuple_arbitrary_len_str),
211+
'--tuple_pair', *stringify(tuple_pair),
212+
'--tuple_arbitrary_len_bool', *stringify(tuple_arbitrary_len_bool),
213+
'--tuple_arbitrary_len_int', *stringify(tuple_arbitrary_len_int),
214+
'--tuple_arbitrary_len_str', *stringify(tuple_arbitrary_len_str),
215215
])
216216

217217
self.assertEqual(args.list_bool, list_bool)
@@ -225,10 +225,10 @@ def test_nested_optional_types(self):
225225
self.assertEqual(args.tuple_bool, tuple_bool)
226226
self.assertEqual(args.tuple_int, tuple_int)
227227
self.assertEqual(args.tuple_str, tuple_str)
228-
# self.assertEqual(args.tuple_pair, tuple_pair)
229-
# self.assertEqual(args.tuple_arbitrary_len_bool, tuple_arbitrary_len_bool)
230-
# self.assertEqual(args.tuple_arbitrary_len_int, tuple_arbitrary_len_int)
231-
# self.assertEqual(args.tuple_arbitrary_len_str, tuple_arbitrary_len_str)
228+
self.assertEqual(args.tuple_pair, tuple_pair)
229+
self.assertEqual(args.tuple_arbitrary_len_bool, tuple_arbitrary_len_bool)
230+
self.assertEqual(args.tuple_arbitrary_len_int, tuple_arbitrary_len_int)
231+
self.assertEqual(args.tuple_arbitrary_len_str, tuple_arbitrary_len_str)
232232

233233

234234
class ComplexTypeTap(Tap):
@@ -261,6 +261,7 @@ def test_complex_types(self):
261261
self.assertEqual(args.set_path, set_path)
262262
self.assertEqual(args.tuple_path, tuple_path)
263263

264+
264265
class Person:
265266
def __init__(self, name: str):
266267
self.name = name
@@ -631,26 +632,32 @@ def configure(self) -> None:
631632
def test_complex_type(self) -> None:
632633
class AddArgumentComplexTypeTap(IntegrationDefaultTap):
633634
arg_person: Person = Person('tap')
634-
# arg_person_required: Person # TODO
635+
arg_person_required: Person
635636
arg_person_untyped = Person('tap untyped')
636637

637-
# TODO: assert a crash if any complex types are not explicitly added in add_argument
638638
def configure(self) -> None:
639639
self.add_argument('--arg_person', type=Person)
640-
# self.add_argument('--arg_person_required', type=Person) # TODO
640+
self.add_argument('--arg_person_required', type=Person)
641641
self.add_argument('--arg_person_untyped', type=Person)
642642

643-
args = AddArgumentComplexTypeTap().parse_args([])
643+
arg_person_required = Person("hello, it's me")
644+
645+
args = AddArgumentComplexTypeTap().parse_args([
646+
'--arg_person_required', arg_person_required.name,
647+
])
644648
self.assertEqual(args.arg_person, Person('tap'))
649+
self.assertEqual(args.arg_person_required, arg_person_required)
645650
self.assertEqual(args.arg_person_untyped, Person('tap untyped'))
646651

647652
arg_person = Person('hi there')
648653
arg_person_untyped = Person('heyyyy')
649654
args = AddArgumentComplexTypeTap().parse_args([
650655
'--arg_person', arg_person.name,
656+
'--arg_person_required', arg_person_required.name,
651657
'--arg_person_untyped', arg_person_untyped.name
652658
])
653659
self.assertEqual(args.arg_person, arg_person)
660+
self.assertEqual(args.arg_person_required, arg_person_required)
654661
self.assertEqual(args.arg_person_untyped, arg_person_untyped)
655662

656663
def test_repeat_default(self) -> None:

0 commit comments

Comments
 (0)