12
12
13
13
import builtins
14
14
import math
15
- from collections .abc import Callable , Mapping
15
+ from collections .abc import Callable
16
16
from copy import copy
17
17
from itertools import chain
18
18
from textwrap import dedent
@@ -59,7 +59,7 @@ class IntegerDivisionError(Exception):
59
59
"""
60
60
61
61
62
- def upcast (dtype , * dtypes ):
62
+ def upcast (dtype , * dtypes ) -> str :
63
63
# This tries to keep data in floatX or lower precision, unless we
64
64
# explicitly request a higher precision datatype.
65
65
keep_float32 = [
@@ -899,31 +899,31 @@ def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
899
899
complexs128 = apply_across_args (complex128 )
900
900
901
901
902
- def upcast_out (* types ):
902
+ def upcast_out (* types ) -> tuple [ ScalarType ] :
903
903
dtype = ScalarType .upcast (* types )
904
904
return (get_scalar_type (dtype ),)
905
905
906
906
907
- def upcast_out_nobool (* types ):
907
+ def upcast_out_nobool (* types ) -> tuple [ ScalarType ] :
908
908
type = upcast_out (* types )
909
909
if type [0 ] == bool :
910
910
raise TypeError ("bool output not supported" )
911
911
return type
912
912
913
913
914
- def upcast_out_min8 (* types ):
914
+ def upcast_out_min8 (* types ) -> tuple [ ScalarType ] :
915
915
type = upcast_out (* types )
916
916
if type [0 ] == bool :
917
917
return (int8 ,)
918
918
return type
919
919
920
920
921
- def upgrade_to_float (* types ):
921
+ def upgrade_to_float (* types ) -> tuple [ ScalarType ] :
922
922
"""
923
923
Upgrade any int types to float32 or float64 to avoid losing precision.
924
924
925
925
"""
926
- conv : Mapping [ type , type ] = {
926
+ conv : dict [ ScalarType , ScalarType ] = {
927
927
bool : float32 ,
928
928
int8 : float32 ,
929
929
int16 : float32 ,
@@ -934,42 +934,41 @@ def upgrade_to_float(*types):
934
934
uint32 : float64 ,
935
935
uint64 : float64 ,
936
936
}
937
- return (
938
- get_scalar_type (ScalarType .upcast (* [conv .get (type , type ) for type in types ])),
939
- )
937
+ up = ScalarType .upcast (* [conv .get (type , type ) for type in types ])
938
+ return (get_scalar_type (up ),)
940
939
941
940
942
- def upgrade_to_float64 (* types ):
941
+ def upgrade_to_float64 (* types ) -> tuple [ ScalarType ] :
943
942
"""
944
943
Upgrade any int and float32 to float64 to do as SciPy.
945
944
946
945
"""
947
946
return (get_scalar_type ("float64" ),)
948
947
949
948
950
- def same_out (type ) :
949
+ def same_out (type : ScalarType ) -> tuple [ ScalarType ] :
951
950
return (type ,)
952
951
953
952
954
- def same_out_nobool (type ) :
953
+ def same_out_nobool (type : ScalarType ) -> tuple [ ScalarType ] :
955
954
if type == bool :
956
955
raise TypeError ("bool input not supported" )
957
956
return (type ,)
958
957
959
958
960
- def same_out_min8 (type ) :
959
+ def same_out_min8 (type : ScalarType ) -> tuple [ ScalarType ] :
961
960
if type == bool :
962
961
return (int8 ,)
963
962
return (type ,)
964
963
965
964
966
- def upcast_out_no_complex (* types ):
965
+ def upcast_out_no_complex (* types ) -> tuple [ ScalarType ] :
967
966
if any (type in complex_types for type in types ):
968
967
raise TypeError ("complex type are not supported" )
969
968
return (get_scalar_type (dtype = ScalarType .upcast (* types )),)
970
969
971
970
972
- def same_out_float_only (type ):
971
+ def same_out_float_only (type ) -> tuple [ ScalarType ] :
973
972
if type not in float_types :
974
973
raise TypeError ("only float type are supported" )
975
974
return (type ,)
0 commit comments