Skip to content

Commit b5682ed

Browse files
committed
Clear mypy errors in scalar/basic.py
1 parent 6cf729b commit b5682ed

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

Diff for: pytensor/scalar/basic.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import builtins
1414
import math
15-
from collections.abc import Callable, Mapping
15+
from collections.abc import Callable
1616
from copy import copy
1717
from itertools import chain
1818
from textwrap import dedent
@@ -59,7 +59,7 @@ class IntegerDivisionError(Exception):
5959
"""
6060

6161

62-
def upcast(dtype, *dtypes):
62+
def upcast(dtype, *dtypes) -> str:
6363
# This tries to keep data in floatX or lower precision, unless we
6464
# explicitly request a higher precision datatype.
6565
keep_float32 = [
@@ -899,31 +899,31 @@ def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
899899
complexs128 = apply_across_args(complex128)
900900

901901

902-
def upcast_out(*types):
902+
def upcast_out(*types) -> tuple[ScalarType]:
903903
dtype = ScalarType.upcast(*types)
904904
return (get_scalar_type(dtype),)
905905

906906

907-
def upcast_out_nobool(*types):
907+
def upcast_out_nobool(*types) -> tuple[ScalarType]:
908908
type = upcast_out(*types)
909909
if type[0] == bool:
910910
raise TypeError("bool output not supported")
911911
return type
912912

913913

914-
def upcast_out_min8(*types):
914+
def upcast_out_min8(*types) -> tuple[ScalarType]:
915915
type = upcast_out(*types)
916916
if type[0] == bool:
917917
return (int8,)
918918
return type
919919

920920

921-
def upgrade_to_float(*types):
921+
def upgrade_to_float(*types) -> tuple[ScalarType]:
922922
"""
923923
Upgrade any int types to float32 or float64 to avoid losing precision.
924924
925925
"""
926-
conv: Mapping[type, type] = {
926+
conv: dict[ScalarType, ScalarType] = {
927927
bool: float32,
928928
int8: float32,
929929
int16: float32,
@@ -934,42 +934,41 @@ def upgrade_to_float(*types):
934934
uint32: float64,
935935
uint64: float64,
936936
}
937-
return (
938-
get_scalar_type(ScalarType.upcast(*[conv.get(type, type) for type in types])),
939-
)
937+
up = ScalarType.upcast(*[conv.get(type, type) for type in types])
938+
return (get_scalar_type(up),)
940939

941940

942-
def upgrade_to_float64(*types):
941+
def upgrade_to_float64(*types) -> tuple[ScalarType]:
943942
"""
944943
Upgrade any int and float32 to float64 to do as SciPy.
945944
946945
"""
947946
return (get_scalar_type("float64"),)
948947

949948

950-
def same_out(type):
949+
def same_out(type: ScalarType) -> tuple[ScalarType]:
951950
return (type,)
952951

953952

954-
def same_out_nobool(type):
953+
def same_out_nobool(type: ScalarType) -> tuple[ScalarType]:
955954
if type == bool:
956955
raise TypeError("bool input not supported")
957956
return (type,)
958957

959958

960-
def same_out_min8(type):
959+
def same_out_min8(type: ScalarType) -> tuple[ScalarType]:
961960
if type == bool:
962961
return (int8,)
963962
return (type,)
964963

965964

966-
def upcast_out_no_complex(*types):
965+
def upcast_out_no_complex(*types) -> tuple[ScalarType]:
967966
if any(type in complex_types for type in types):
968967
raise TypeError("complex type are not supported")
969968
return (get_scalar_type(dtype=ScalarType.upcast(*types)),)
970969

971970

972-
def same_out_float_only(type):
971+
def same_out_float_only(type) -> tuple[ScalarType]:
973972
if type not in float_types:
974973
raise TypeError("only float type are supported")
975974
return (type,)

Diff for: scripts/mypy-failing.txt

-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ pytensor/link/numba/dispatch/elemwise.py
1111
pytensor/link/numba/dispatch/scan.py
1212
pytensor/printing.py
1313
pytensor/raise_op.py
14-
pytensor/scalar/basic.py
1514
pytensor/sparse/basic.py
1615
pytensor/sparse/type.py
1716
pytensor/tensor/basic.py

0 commit comments

Comments (0)