@@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True):
 _py_scalars = (bool, int, float, complex)
 
 
-def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType:
+def result_type(
+    *arrays_and_dtypes: Array | DType | bool | int | float | complex
+) -> DType:
     num = len(arrays_and_dtypes)
 
     if num == 0:
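With the widened annotation, bare Python scalars are valid arguments alongside arrays and dtypes. A minimal usage sketch, assuming the wrapper is consumed through its public `array_api_compat.torch` namespace (the `xp` alias is illustrative):

```python
import torch
import array_api_compat.torch as xp  # assumed public namespace for this module

# A Python float among the arguments now matches the annotation; per the
# array API scalar promotion rules it defers to the array's dtype kind:
dt = xp.result_type(torch.ones(3, dtype=torch.float32), 1.0)
assert dt == torch.float32
```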
@@ -550,10 +552,16 @@ def count_nonzero(
     return result
 
 
-def where(condition: Array, x1: Array, x2: Array, /) -> Array:
+def where(
+    condition: Array,
+    x1: Array | bool | int | float | complex,
+    x2: Array | bool | int | float | complex,
+    /,
+) -> Array:
     x1, x2 = _fix_promotion(x1, x2)
     return torch.where(condition, x1, x2)
 
+
 # torch.reshape doesn't have the copy keyword
 def reshape(x: Array,
             /,
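A sketch of the scalar branches the new `where` annotation covers, under the same assumed `xp` namespace as above; `_fix_promotion` returns scalars unchanged (that is what the `_py_scalars` tuple is checked against) and `torch.where` itself accepts a scalar branch:

```python
import torch
import array_api_compat.torch as xp  # assumed public namespace

cond = torch.tensor([True, False, True])
# x1 may now be a Python scalar rather than an array:
out = xp.where(cond, 1.0, torch.zeros(3))
# tensor([1., 0., 1.])
```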
@@ -622,7 +630,7 @@ def linspace(start: Union[int, float],
 # torch.full does not accept an int size
 # https://github.com/pytorch/pytorch/issues/70906
 def full(shape: Union[int, Tuple[int, ...]],
-         fill_value: complex,
+         fill_value: bool | int | float | complex,
          *,
          dtype: Optional[DType] = None,
          device: Optional[Device] = None,
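And a corresponding sketch for `full`, where a `bool` fill value now type-checks and, with `dtype` omitted, the result dtype is inferred from the scalar (again assuming the `array_api_compat.torch` namespace):

```python
import array_api_compat.torch as xp  # assumed public namespace

out = xp.full((2, 2), True)  # bool fill value now matches the annotation
# tensor([[True, True],
#         [True, True]])  -> dtype inferred as torch.bool
```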