|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from functools import wraps |
4 |
| -from builtins import all as builtin_all, any as builtin_any |
| 3 | +from functools import wraps as _wraps |
| 4 | +from builtins import all as _builtin_all, any as _builtin_any |
5 | 5 |
|
6 |
| -from ..common._aliases import (UniqueAllResult, UniqueCountsResult, |
7 |
| - UniqueInverseResult, |
8 |
| - matrix_transpose as _aliases_matrix_transpose, |
| 6 | +from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, |
9 | 7 | vecdot as _aliases_vecdot)
|
10 | 8 | from .._internal import get_xp
|
11 | 9 |
|
|
86 | 84 |
|
87 | 85 |
|
88 | 86 | def _two_arg(f):
|
89 |
| - @wraps(f) |
| 87 | + @_wraps(f) |
90 | 88 | def _f(x1, x2, /, **kwargs):
|
91 | 89 | x1, x2 = _fix_promotion(x1, x2)
|
92 | 90 | return f(x1, x2, **kwargs)
|
@@ -509,7 +507,7 @@ def arange(start: Union[int, float],
|
509 | 507 | start, stop = 0, start
|
510 | 508 | if step > 0 and stop <= start or step < 0 and stop >= start:
|
511 | 509 | if dtype is None:
|
512 |
| - if builtin_all(isinstance(i, int) for i in [start, stop, step]): |
| 510 | + if _builtin_all(isinstance(i, int) for i in [start, stop, step]): |
513 | 511 | dtype = torch.int64
|
514 | 512 | else:
|
515 | 513 | dtype = torch.float32
|
@@ -601,6 +599,11 @@ def broadcast_arrays(*arrays: array) -> List[array]:
|
601 | 599 | shape = torch.broadcast_shapes(*[a.shape for a in arrays])
|
602 | 600 | return [torch.broadcast_to(a, shape) for a in arrays]
|
603 | 601 |
|
| 602 | +# Note that these named tuples aren't actually part of the standard namespace, |
| 603 | +# but I don't see any issue with exporting the names here regardless. |
| 604 | +from ..common._aliases import (UniqueAllResult, UniqueCountsResult, |
| 605 | + UniqueInverseResult) |
| 606 | + |
604 | 607 | # https://github.com/pytorch/pytorch/issues/70920
|
605 | 608 | def unique_all(x: array) -> UniqueAllResult:
|
606 | 609 | # torch.unique doesn't support returning indices.
|
@@ -665,7 +668,7 @@ def isdtype(
|
665 | 668 | for more details
|
666 | 669 | """
|
667 | 670 | if isinstance(kind, tuple) and _tuple:
|
668 |
| - return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) |
| 671 | + return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) |
669 | 672 | elif isinstance(kind, str):
|
670 | 673 | if kind == 'bool':
|
671 | 674 | return dtype == torch.bool
|
@@ -693,15 +696,19 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
|
693 | 696 | axis = 0
|
694 | 697 | return torch.index_select(x, axis, indices, **kwargs)
|
695 | 698 |
|
696 |
| -__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', |
697 |
| - 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', |
698 |
| - 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', |
699 |
| - 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', |
700 |
| - 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', |
701 |
| - 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', |
702 |
| - 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', |
703 |
| - 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', |
704 |
| - 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', |
705 |
| - 'broadcast_arrays', 'unique_all', 'unique_counts', |
706 |
| - 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', |
707 |
| - 'vecdot', 'tensordot', 'isdtype', 'take'] |
| 699 | +__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', |
| 700 | + 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', |
| 701 | + 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', |
| 702 | + 'equal', 'floor_divide', 'greater', 'greater_equal', 'less', |
| 703 | + 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', |
| 704 | + 'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum', |
| 705 | + 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', |
| 706 | + 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', |
| 707 | + 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', |
| 708 | + 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', |
| 709 | + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', |
| 710 | + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', |
| 711 | + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', |
| 712 | + 'take'] |
| 713 | + |
| 714 | +_all_ignore = ['torch', 'get_xp'] |
0 commit comments