|
| 1 | +from argparse import ArgumentTypeError |
1 | 2 | from copy import deepcopy
|
2 | 3 | import os
|
3 | 4 | from pathlib import Path
|
4 | 5 | import pickle
|
5 | 6 | import sys
|
6 | 7 | from tempfile import TemporaryDirectory
|
7 |
| -from typing import Any, Iterable, List, Optional, Set, Tuple |
| 8 | +from typing import Any, Iterable, List, Optional, Set, Tuple, Union |
8 | 9 | from typing_extensions import Literal
|
9 | 10 | import unittest
|
10 | 11 | from unittest import TestCase
|
@@ -556,6 +557,186 @@ def write(self, msg):
|
556 | 557 | LiteralCrashTap().parse_args(['--arg_lit', '123'])
|
557 | 558 |
|
558 | 559 |
|
| 560 | +def convert_str_or_int(str_or_int: str) -> Union[str, int]: |
| 561 | + try: |
| 562 | + return int(str_or_int) |
| 563 | + except ValueError: |
| 564 | + return str_or_int |
| 565 | + |
| 566 | + |
| 567 | +def convert_person_or_str(person_or_str: str) -> Union[Person, str]: |
| 568 | + if person_or_str == person_or_str.title(): |
| 569 | + return Person(person_or_str) |
| 570 | + |
| 571 | + return person_or_str |
| 572 | + |
| 573 | + |
| 574 | +def convert_many_types(input_str: str) -> Union[int, float, Person, str]: |
| 575 | + try: |
| 576 | + return int(input_str) |
| 577 | + except ValueError: |
| 578 | + try: |
| 579 | + return float(input_str) |
| 580 | + except ValueError: |
| 581 | + if input_str == input_str.title(): |
| 582 | + return Person(input_str) |
| 583 | + |
| 584 | + return input_str |
| 585 | + |
| 586 | + |
| 587 | +# TODO: test crash if not specifying type function |
| 588 | +class UnionTypeTap(Tap): |
| 589 | + union_zero_required_arg: Union |
| 590 | + union_zero_default_arg: Union = 'hi' |
| 591 | + union_one_required_arg: Union[str] |
| 592 | + union_one_default_arg: Union[str] = 'there' |
| 593 | + union_two_required_arg: Union[str, int] |
| 594 | + union_two_default_int_arg: Union[str, int] = 5 |
| 595 | + union_two_default_str_arg: Union[str, int] = 'year old' |
| 596 | + union_custom_required_arg: Union[Person, str] |
| 597 | + union_custom_required_flip_arg: Union[str, Person] |
| 598 | + union_custom_default_arg: Union[Person, str] = Person('Jesse') |
| 599 | + union_custom_default_flip_arg: Union[str, Person] = 'I want' |
| 600 | + union_none_required_arg: Union[int, None] |
| 601 | + union_none_required_flip_arg: Union[None, int] |
| 602 | + union_many_required_arg: Union[int, float, Person, str] |
| 603 | + union_many_default_arg: Union[int, float, Person, str] = 3.14 |
| 604 | + |
| 605 | + def configure(self) -> None: |
| 606 | + self.add_argument('--union_two_required_arg', type=convert_str_or_int) |
| 607 | + self.add_argument('--union_two_default_int_arg', type=convert_str_or_int) |
| 608 | + self.add_argument('--union_two_default_str_arg', type=convert_str_or_int) |
| 609 | + self.add_argument('--union_custom_required_arg', type=convert_person_or_str) |
| 610 | + self.add_argument('--union_custom_required_flip_arg', type=convert_person_or_str) |
| 611 | + self.add_argument('--union_custom_default_arg', type=convert_person_or_str) |
| 612 | + self.add_argument('--union_custom_default_flip_arg', type=convert_person_or_str) |
| 613 | + self.add_argument('--union_none_required_flip_arg', type=int) |
| 614 | + self.add_argument('--union_many_required_arg', type=convert_many_types) |
| 615 | + self.add_argument('--union_many_default_arg', type=convert_many_types) |
| 616 | + |
| 617 | + |
| 618 | +if sys.version_info >= (3, 10): |
| 619 | + class UnionType310Tap(Tap): |
| 620 | + union_two_required_arg: str | int |
| 621 | + union_two_default_int_arg: str | int = 10 |
| 622 | + union_two_default_str_arg: str | int = 'pieces of pie for' |
| 623 | + union_custom_required_arg: Person | str |
| 624 | + union_custom_required_flip_arg: str | Person |
| 625 | + union_custom_default_arg: Person | str = Person('Kyle') |
| 626 | + union_custom_default_flip_arg: str | Person = 'making' |
| 627 | + union_none_required_arg: int | None |
| 628 | + union_none_required_flip_arg: None | int |
| 629 | + union_many_required_arg: int | float | Person | str |
| 630 | + union_many_default_arg: int | float | Person | str = 3.14 * 10 / 8 |
| 631 | + |
| 632 | + def configure(self) -> None: |
| 633 | + self.add_argument('--union_two_required_arg', type=convert_str_or_int) |
| 634 | + self.add_argument('--union_two_default_int_arg', type=convert_str_or_int) |
| 635 | + self.add_argument('--union_two_default_str_arg', type=convert_str_or_int) |
| 636 | + self.add_argument('--union_custom_required_arg', type=convert_person_or_str) |
| 637 | + self.add_argument('--union_custom_required_flip_arg', type=convert_person_or_str) |
| 638 | + self.add_argument('--union_custom_default_arg', type=convert_person_or_str) |
| 639 | + self.add_argument('--union_custom_default_flip_arg', type=convert_person_or_str) |
| 640 | + self.add_argument('--union_none_required_flip_arg', type=int) |
| 641 | + self.add_argument('--union_many_required_arg', type=convert_many_types) |
| 642 | + self.add_argument('--union_many_default_arg', type=convert_many_types) |
| 643 | + |
| 644 | + |
| 645 | +class UnionTypeTests(TestCase): |
| 646 | + def test_union_types(self): |
| 647 | + union_zero_required_arg = 'Kyle' |
| 648 | + union_one_required_arg = 'ate' |
| 649 | + union_two_required_arg = '2' |
| 650 | + union_custom_required_arg = 'many' |
| 651 | + union_custom_required_flip_arg = 'Jesse' |
| 652 | + union_none_required_arg = '1' |
| 653 | + union_none_required_flip_arg = '5' |
| 654 | + union_many_required_arg = 'still hungry' |
| 655 | + |
| 656 | + args = UnionTypeTap().parse_args([ |
| 657 | + '--union_zero_required_arg', union_zero_required_arg, |
| 658 | + '--union_one_required_arg', union_one_required_arg, |
| 659 | + '--union_two_required_arg', union_two_required_arg, |
| 660 | + '--union_custom_required_arg', union_custom_required_arg, |
| 661 | + '--union_custom_required_flip_arg', union_custom_required_flip_arg, |
| 662 | + '--union_none_required_arg', union_none_required_arg, |
| 663 | + '--union_none_required_flip_arg', union_none_required_flip_arg, |
| 664 | + '--union_many_required_arg', union_many_required_arg |
| 665 | + ]) |
| 666 | + |
| 667 | + union_two_required_arg = int(union_two_required_arg) |
| 668 | + union_custom_required_flip_arg = Person(union_custom_required_flip_arg) |
| 669 | + union_none_required_arg = int(union_none_required_arg) |
| 670 | + union_none_required_flip_arg = int(union_none_required_flip_arg) |
| 671 | + |
| 672 | + self.assertEqual(args.union_zero_required_arg, union_zero_required_arg) |
| 673 | + self.assertEqual(args.union_zero_default_arg, UnionTypeTap.union_zero_default_arg) |
| 674 | + self.assertEqual(args.union_one_required_arg, union_one_required_arg) |
| 675 | + self.assertEqual(args.union_one_default_arg, UnionTypeTap.union_one_default_arg) |
| 676 | + self.assertEqual(args.union_two_required_arg, union_two_required_arg) |
| 677 | + self.assertEqual(args.union_two_default_int_arg, UnionTypeTap.union_two_default_int_arg) |
| 678 | + self.assertEqual(args.union_two_default_str_arg, UnionTypeTap.union_two_default_str_arg) |
| 679 | + self.assertEqual(args.union_custom_required_arg, union_custom_required_arg) |
| 680 | + self.assertEqual(args.union_custom_required_flip_arg, union_custom_required_flip_arg) |
| 681 | + self.assertEqual(args.union_custom_default_arg, UnionTypeTap.union_custom_default_arg) |
| 682 | + self.assertEqual(args.union_custom_default_flip_arg, UnionTypeTap.union_custom_default_flip_arg) |
| 683 | + self.assertEqual(args.union_none_required_arg, union_none_required_arg) |
| 684 | + self.assertEqual(args.union_none_required_flip_arg, union_none_required_flip_arg) |
| 685 | + self.assertEqual(args.union_many_required_arg, union_many_required_arg) |
| 686 | + self.assertEqual(args.union_many_default_arg, UnionTypeTap.union_many_default_arg) |
| 687 | + |
| 688 | + def test_union_missing_type_function(self): |
| 689 | + class UnionMissingTypeFunctionTap(Tap): |
| 690 | + arg: Union[int, float] |
| 691 | + |
| 692 | + with self.assertRaises(ArgumentTypeError): |
| 693 | + UnionMissingTypeFunctionTap() |
| 694 | + |
| 695 | + @unittest.skipIf(sys.version_info < (3, 10), 'Union type operator "|" introduced in Python 3.10') |
| 696 | + def test_union_types_310(self): |
| 697 | + union_two_required_arg = '1' # int |
| 698 | + union_custom_required_arg = 'hungry' # str |
| 699 | + union_custom_required_flip_arg = 'Loser' # Person |
| 700 | + union_none_required_arg = '8' # int |
| 701 | + union_none_required_flip_arg = '100' # int |
| 702 | + union_many_required_arg = '3.14' # float |
| 703 | + |
| 704 | + args = UnionType310Tap().parse_args([ |
| 705 | + '--union_two_required_arg', union_two_required_arg, |
| 706 | + '--union_custom_required_arg', union_custom_required_arg, |
| 707 | + '--union_custom_required_flip_arg', union_custom_required_flip_arg, |
| 708 | + '--union_none_required_arg', union_none_required_arg, |
| 709 | + '--union_none_required_flip_arg', union_none_required_flip_arg, |
| 710 | + '--union_many_required_arg', union_many_required_arg |
| 711 | + ]) |
| 712 | + |
| 713 | + union_two_required_arg = int(union_two_required_arg) |
| 714 | + union_custom_required_flip_arg = Person(union_custom_required_flip_arg) |
| 715 | + union_none_required_arg = int(union_none_required_arg) |
| 716 | + union_none_required_flip_arg = int(union_none_required_flip_arg) |
| 717 | + union_many_required_arg = float(union_many_required_arg) |
| 718 | + |
| 719 | + self.assertEqual(args.union_two_required_arg, union_two_required_arg) |
| 720 | + self.assertEqual(args.union_two_default_int_arg, UnionType310Tap.union_two_default_int_arg) |
| 721 | + self.assertEqual(args.union_two_default_str_arg, UnionType310Tap.union_two_default_str_arg) |
| 722 | + self.assertEqual(args.union_custom_required_arg, union_custom_required_arg) |
| 723 | + self.assertEqual(args.union_custom_required_flip_arg, union_custom_required_flip_arg) |
| 724 | + self.assertEqual(args.union_custom_default_arg, UnionType310Tap.union_custom_default_arg) |
| 725 | + self.assertEqual(args.union_custom_default_flip_arg, UnionType310Tap.union_custom_default_flip_arg) |
| 726 | + self.assertEqual(args.union_none_required_arg, union_none_required_arg) |
| 727 | + self.assertEqual(args.union_none_required_flip_arg, union_none_required_flip_arg) |
| 728 | + self.assertEqual(args.union_many_required_arg, union_many_required_arg) |
| 729 | + self.assertEqual(args.union_many_default_arg, UnionType310Tap.union_many_default_arg) |
| 730 | + |
| 731 | + @unittest.skipIf(sys.version_info < (3, 10), 'Union type operator "|" introduced in Python 3.10') |
| 732 | + def test_union_missing_type_function_310(self): |
| 733 | + class UnionMissingTypeFunctionTap(Tap): |
| 734 | + arg: int | float |
| 735 | + |
| 736 | + with self.assertRaises(ArgumentTypeError): |
| 737 | + UnionMissingTypeFunctionTap() |
| 738 | + |
| 739 | + |
559 | 740 | class AddArgumentTests(TestCase):
|
560 | 741 | def setUp(self) -> None:
|
561 | 742 | # Suppress prints from SystemExit
|
@@ -1044,7 +1225,7 @@ def test_empty_tuple_fails(self):
|
1044 | 1225 | class EmptyTupleTap(Tap):
|
1045 | 1226 | tup: Tuple[()]
|
1046 | 1227 |
|
1047 |
| - with self.assertRaises(ValueError): |
| 1228 | + with self.assertRaises(ArgumentTypeError): |
1048 | 1229 | EmptyTupleTap().parse_args([])
|
1049 | 1230 |
|
1050 | 1231 | def test_tuple_non_tuple_default(self):
|
@@ -1380,14 +1561,5 @@ def test_pickle(self):
|
1380 | 1561 | self.assertEqual(loaded_args.as_dict(), args.as_dict())
|
1381 | 1562 |
|
1382 | 1563 |
|
1383 |
| -""" |
1384 |
| -- crash if default type not supported |
1385 |
| -- user specifying process_args |
1386 |
| -- test get reproducibility info |
1387 |
| -- test str? |
1388 |
| -- test comments |
1389 |
| -""" |
1390 |
| - |
1391 |
| - |
1392 | 1564 | if __name__ == '__main__':
|
1393 | 1565 | unittest.main()
|
0 commit comments