Skip to content

Commit 9845d1b

Browse files
committed
Adding Union type tests, raising an error when a type function is not provided for union types, and fixing exception types in tests
1 parent ea8254d commit 9845d1b

File tree

6 files changed

+219
-20
lines changed

6 files changed

+219
-20
lines changed

.github/workflows/code-coverage.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ jobs:
66
run:
77
runs-on: ubuntu-latest
88
env:
9-
PYTHON: '3.9'
9+
PYTHON: '3.10'
1010
steps:
1111
- uses: actions/checkout@main
1212
- name: Setup Python
1313
uses: actions/setup-python@main
1414
with:
15-
python-version: 3.9
15+
python-version: 3.10
1616
- name: Generate coverage report
1717
run: |
1818
git config --global user.email "[email protected]"

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
matrix:
1818
os: [ubuntu-latest, macos-latest, windows-latest]
19-
python-version: [3.6, 3.7, 3.8, 3.9]
19+
python-version: [3.6, 3.7, 3.8, 3.9, 3.10]
2020

2121
steps:
2222
- uses: actions/checkout@v2

tap/tap.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
import time
1111
from types import MethodType
1212
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
13-
from typing_inspect import is_literal_type, get_args
13+
from typing_inspect import is_literal_type
1414
from warnings import warn
1515

1616
from tap.utils import (
1717
get_class_variables,
18+
get_args,
1819
get_argument_name,
1920
get_dest,
2021
get_origin,
@@ -37,7 +38,8 @@
3738
# Constants
3839
EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple()
3940
BOXED_COLLECTION_TYPES = {List, list, Set, set, Tuple, tuple}
40-
OPTIONAL_TYPES = {Optional, Union} | ({UnionType} if sys.version_info >= (3, 10) else set())
41+
UNION_TYPES = {Union} | ({UnionType} if sys.version_info >= (3, 10) else set())
42+
OPTIONAL_TYPES = {Optional} | UNION_TYPES
4143
BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES
4244

4345

@@ -172,11 +174,26 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
172174

173175
# If type is not explicitly provided, set it if it's one of our supported default types
174176
if 'type' not in kwargs:
175-
176177
# Unbox Union[type] (Optional[type]) and set var_type = type
177178
if get_origin(var_type) in OPTIONAL_TYPES:
178179
var_args = get_args(var_type)
179180

181+
# If type is Union or Optional without inner types, set type to equivalent of Optional[str]
182+
if len(var_args) == 0:
183+
var_args = (str, type(None))
184+
185+
# Raise error if type function is not explicitly provided for Union types (not including Optionals)
186+
if get_origin(var_type) in UNION_TYPES and not (len(var_args) == 2 and var_args[1] == type(None)):
187+
raise ArgumentTypeError(
188+
'For Union types, you must include an explicit type function in the configure method. '
189+
'For example,\n\n'
190+
'class Args(Tap):\n'
191+
' arg: Union[int, float]\n'
192+
'\n'
193+
' def configure(self) -> None:\n'
194+
' self.add_argument("--arg", type=lambda x: float(x) if "." in x else int(x))'
195+
)
196+
180197
if len(var_args) > 0:
181198
var_type = var_args[0]
182199

@@ -242,7 +259,7 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
242259
kwargs['type'] = boolean_type
243260
kwargs['choices'] = [True, False] # this makes the help message more helpful
244261
else:
245-
action_cond = "true" if kwargs.get("required", False) or not kwargs["default"] else "false"
262+
action_cond = 'true' if kwargs.get('required', False) or not kwargs['default'] else 'false'
246263
kwargs['action'] = kwargs.get('action', f'store_{action_cond}')
247264
elif kwargs.get('action') not in {'count', 'append_const'}:
248265
kwargs['type'] = var_type

tap/utils.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
Union,
2525
)
2626
from typing_extensions import Literal
27-
from typing_inspect import get_args, get_origin as typing_inspect_get_origin
27+
from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
2828

2929
if sys.version_info >= (3, 10):
3030
from types import UnionType
@@ -487,3 +487,12 @@ def get_origin(tp: Any) -> Any:
487487
origin = UnionType
488488

489489
return origin
490+
491+
492+
# TODO: remove this once typing_insepct.get_args is fixed for Python 3.10 union types
493+
def get_args(tp: Any) -> Tuple[type, ...]:
494+
"""Same as typing_inspect.get_args but fixes Python 3.10 union types."""
495+
if sys.version_info >= (3, 10) and isinstance(tp, UnionType):
496+
return tp.__args__
497+
498+
return typing_inspect_get_args(tp)

tests/test_integration.py

+183-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from argparse import ArgumentTypeError
12
from copy import deepcopy
23
import os
34
from pathlib import Path
45
import pickle
56
import sys
67
from tempfile import TemporaryDirectory
7-
from typing import Any, Iterable, List, Optional, Set, Tuple
8+
from typing import Any, Iterable, List, Optional, Set, Tuple, Union
89
from typing_extensions import Literal
910
import unittest
1011
from unittest import TestCase
@@ -556,6 +557,186 @@ def write(self, msg):
556557
LiteralCrashTap().parse_args(['--arg_lit', '123'])
557558

558559

560+
def convert_str_or_int(str_or_int: str) -> Union[str, int]:
561+
try:
562+
return int(str_or_int)
563+
except ValueError:
564+
return str_or_int
565+
566+
567+
def convert_person_or_str(person_or_str: str) -> Union[Person, str]:
568+
if person_or_str == person_or_str.title():
569+
return Person(person_or_str)
570+
571+
return person_or_str
572+
573+
574+
def convert_many_types(input_str: str) -> Union[int, float, Person, str]:
575+
try:
576+
return int(input_str)
577+
except ValueError:
578+
try:
579+
return float(input_str)
580+
except ValueError:
581+
if input_str == input_str.title():
582+
return Person(input_str)
583+
584+
return input_str
585+
586+
587+
# TODO: test crash if not specifying type function
588+
class UnionTypeTap(Tap):
589+
union_zero_required_arg: Union
590+
union_zero_default_arg: Union = 'hi'
591+
union_one_required_arg: Union[str]
592+
union_one_default_arg: Union[str] = 'there'
593+
union_two_required_arg: Union[str, int]
594+
union_two_default_int_arg: Union[str, int] = 5
595+
union_two_default_str_arg: Union[str, int] = 'year old'
596+
union_custom_required_arg: Union[Person, str]
597+
union_custom_required_flip_arg: Union[str, Person]
598+
union_custom_default_arg: Union[Person, str] = Person('Jesse')
599+
union_custom_default_flip_arg: Union[str, Person] = 'I want'
600+
union_none_required_arg: Union[int, None]
601+
union_none_required_flip_arg: Union[None, int]
602+
union_many_required_arg: Union[int, float, Person, str]
603+
union_many_default_arg: Union[int, float, Person, str] = 3.14
604+
605+
def configure(self) -> None:
606+
self.add_argument('--union_two_required_arg', type=convert_str_or_int)
607+
self.add_argument('--union_two_default_int_arg', type=convert_str_or_int)
608+
self.add_argument('--union_two_default_str_arg', type=convert_str_or_int)
609+
self.add_argument('--union_custom_required_arg', type=convert_person_or_str)
610+
self.add_argument('--union_custom_required_flip_arg', type=convert_person_or_str)
611+
self.add_argument('--union_custom_default_arg', type=convert_person_or_str)
612+
self.add_argument('--union_custom_default_flip_arg', type=convert_person_or_str)
613+
self.add_argument('--union_none_required_flip_arg', type=int)
614+
self.add_argument('--union_many_required_arg', type=convert_many_types)
615+
self.add_argument('--union_many_default_arg', type=convert_many_types)
616+
617+
618+
if sys.version_info >= (3, 10):
619+
class UnionType310Tap(Tap):
620+
union_two_required_arg: str | int
621+
union_two_default_int_arg: str | int = 10
622+
union_two_default_str_arg: str | int = 'pieces of pie for'
623+
union_custom_required_arg: Person | str
624+
union_custom_required_flip_arg: str | Person
625+
union_custom_default_arg: Person | str = Person('Kyle')
626+
union_custom_default_flip_arg: str | Person = 'making'
627+
union_none_required_arg: int | None
628+
union_none_required_flip_arg: None | int
629+
union_many_required_arg: int | float | Person | str
630+
union_many_default_arg: int | float | Person | str = 3.14 * 10 / 8
631+
632+
def configure(self) -> None:
633+
self.add_argument('--union_two_required_arg', type=convert_str_or_int)
634+
self.add_argument('--union_two_default_int_arg', type=convert_str_or_int)
635+
self.add_argument('--union_two_default_str_arg', type=convert_str_or_int)
636+
self.add_argument('--union_custom_required_arg', type=convert_person_or_str)
637+
self.add_argument('--union_custom_required_flip_arg', type=convert_person_or_str)
638+
self.add_argument('--union_custom_default_arg', type=convert_person_or_str)
639+
self.add_argument('--union_custom_default_flip_arg', type=convert_person_or_str)
640+
self.add_argument('--union_none_required_flip_arg', type=int)
641+
self.add_argument('--union_many_required_arg', type=convert_many_types)
642+
self.add_argument('--union_many_default_arg', type=convert_many_types)
643+
644+
645+
class UnionTypeTests(TestCase):
646+
def test_union_types(self):
647+
union_zero_required_arg = 'Kyle'
648+
union_one_required_arg = 'ate'
649+
union_two_required_arg = '2'
650+
union_custom_required_arg = 'many'
651+
union_custom_required_flip_arg = 'Jesse'
652+
union_none_required_arg = '1'
653+
union_none_required_flip_arg = '5'
654+
union_many_required_arg = 'still hungry'
655+
656+
args = UnionTypeTap().parse_args([
657+
'--union_zero_required_arg', union_zero_required_arg,
658+
'--union_one_required_arg', union_one_required_arg,
659+
'--union_two_required_arg', union_two_required_arg,
660+
'--union_custom_required_arg', union_custom_required_arg,
661+
'--union_custom_required_flip_arg', union_custom_required_flip_arg,
662+
'--union_none_required_arg', union_none_required_arg,
663+
'--union_none_required_flip_arg', union_none_required_flip_arg,
664+
'--union_many_required_arg', union_many_required_arg
665+
])
666+
667+
union_two_required_arg = int(union_two_required_arg)
668+
union_custom_required_flip_arg = Person(union_custom_required_flip_arg)
669+
union_none_required_arg = int(union_none_required_arg)
670+
union_none_required_flip_arg = int(union_none_required_flip_arg)
671+
672+
self.assertEqual(args.union_zero_required_arg, union_zero_required_arg)
673+
self.assertEqual(args.union_zero_default_arg, UnionTypeTap.union_zero_default_arg)
674+
self.assertEqual(args.union_one_required_arg, union_one_required_arg)
675+
self.assertEqual(args.union_one_default_arg, UnionTypeTap.union_one_default_arg)
676+
self.assertEqual(args.union_two_required_arg, union_two_required_arg)
677+
self.assertEqual(args.union_two_default_int_arg, UnionTypeTap.union_two_default_int_arg)
678+
self.assertEqual(args.union_two_default_str_arg, UnionTypeTap.union_two_default_str_arg)
679+
self.assertEqual(args.union_custom_required_arg, union_custom_required_arg)
680+
self.assertEqual(args.union_custom_required_flip_arg, union_custom_required_flip_arg)
681+
self.assertEqual(args.union_custom_default_arg, UnionTypeTap.union_custom_default_arg)
682+
self.assertEqual(args.union_custom_default_flip_arg, UnionTypeTap.union_custom_default_flip_arg)
683+
self.assertEqual(args.union_none_required_arg, union_none_required_arg)
684+
self.assertEqual(args.union_none_required_flip_arg, union_none_required_flip_arg)
685+
self.assertEqual(args.union_many_required_arg, union_many_required_arg)
686+
self.assertEqual(args.union_many_default_arg, UnionTypeTap.union_many_default_arg)
687+
688+
def test_union_missing_type_function(self):
689+
class UnionMissingTypeFunctionTap(Tap):
690+
arg: Union[int, float]
691+
692+
with self.assertRaises(ArgumentTypeError):
693+
UnionMissingTypeFunctionTap()
694+
695+
@unittest.skipIf(sys.version_info < (3, 10), 'Union type operator "|" introduced in Python 3.10')
696+
def test_union_types_310(self):
697+
union_two_required_arg = '1' # int
698+
union_custom_required_arg = 'hungry' # str
699+
union_custom_required_flip_arg = 'Loser' # Person
700+
union_none_required_arg = '8' # int
701+
union_none_required_flip_arg = '100' # int
702+
union_many_required_arg = '3.14' # float
703+
704+
args = UnionType310Tap().parse_args([
705+
'--union_two_required_arg', union_two_required_arg,
706+
'--union_custom_required_arg', union_custom_required_arg,
707+
'--union_custom_required_flip_arg', union_custom_required_flip_arg,
708+
'--union_none_required_arg', union_none_required_arg,
709+
'--union_none_required_flip_arg', union_none_required_flip_arg,
710+
'--union_many_required_arg', union_many_required_arg
711+
])
712+
713+
union_two_required_arg = int(union_two_required_arg)
714+
union_custom_required_flip_arg = Person(union_custom_required_flip_arg)
715+
union_none_required_arg = int(union_none_required_arg)
716+
union_none_required_flip_arg = int(union_none_required_flip_arg)
717+
union_many_required_arg = float(union_many_required_arg)
718+
719+
self.assertEqual(args.union_two_required_arg, union_two_required_arg)
720+
self.assertEqual(args.union_two_default_int_arg, UnionType310Tap.union_two_default_int_arg)
721+
self.assertEqual(args.union_two_default_str_arg, UnionType310Tap.union_two_default_str_arg)
722+
self.assertEqual(args.union_custom_required_arg, union_custom_required_arg)
723+
self.assertEqual(args.union_custom_required_flip_arg, union_custom_required_flip_arg)
724+
self.assertEqual(args.union_custom_default_arg, UnionType310Tap.union_custom_default_arg)
725+
self.assertEqual(args.union_custom_default_flip_arg, UnionType310Tap.union_custom_default_flip_arg)
726+
self.assertEqual(args.union_none_required_arg, union_none_required_arg)
727+
self.assertEqual(args.union_none_required_flip_arg, union_none_required_flip_arg)
728+
self.assertEqual(args.union_many_required_arg, union_many_required_arg)
729+
self.assertEqual(args.union_many_default_arg, UnionType310Tap.union_many_default_arg)
730+
731+
@unittest.skipIf(sys.version_info < (3, 10), 'Union type operator "|" introduced in Python 3.10')
732+
def test_union_missing_type_function_310(self):
733+
class UnionMissingTypeFunctionTap(Tap):
734+
arg: int | float
735+
736+
with self.assertRaises(ArgumentTypeError):
737+
UnionMissingTypeFunctionTap()
738+
739+
559740
class AddArgumentTests(TestCase):
560741
def setUp(self) -> None:
561742
# Suppress prints from SystemExit
@@ -1044,7 +1225,7 @@ def test_empty_tuple_fails(self):
10441225
class EmptyTupleTap(Tap):
10451226
tup: Tuple[()]
10461227

1047-
with self.assertRaises(ValueError):
1228+
with self.assertRaises(ArgumentTypeError):
10481229
EmptyTupleTap().parse_args([])
10491230

10501231
def test_tuple_non_tuple_default(self):
@@ -1380,14 +1561,5 @@ def test_pickle(self):
13801561
self.assertEqual(loaded_args.as_dict(), args.as_dict())
13811562

13821563

1383-
"""
1384-
- crash if default type not supported
1385-
- user specifying process_args
1386-
- test get reproducibility info
1387-
- test str?
1388-
- test comments
1389-
"""
1390-
1391-
13921564
if __name__ == '__main__':
13931565
unittest.main()

tests/test_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from argparse import ArgumentTypeError
12
from collections import OrderedDict
23
import json
34
import os
@@ -297,7 +298,7 @@ def test_get_literals_primitives(self) -> None:
297298
self.assertEqual([literal_f(str(p)) for p in prims], literals)
298299

299300
def test_get_literals_uniqueness(self) -> None:
300-
with self.assertRaises(ValueError):
301+
with self.assertRaises(ArgumentTypeError):
301302
get_literals(Literal['two', 2, '2'], 'number')
302303

303304
def test_get_literals_empty(self) -> None:

0 commit comments

Comments
 (0)