Commit bc350b3

pytorchmergebot authored and pobin6 committed
Revert "Tighten type hints for tensor arithmetic (pytorch#135392)"
This reverts commit d378819. Reverted pytorch#135392 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. See D65641103 for more details ([comment](pytorch#135392 (comment)))
1 parent 439b077 commit bc350b3

File tree

4 files changed: +28 -46 lines changed

tools/pyi/gen_pyi.py (+20 -40)
torch/_decomp/decompositions.py (+1 -2)
torch/_inductor/fx_passes/efficient_conv_bn_eval.py (-1)
torch/ao/quantization/_equalize.py (+7 -3)

Diff for: tools/pyi/gen_pyi.py (+20 -40)

@@ -177,18 +177,14 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
     "copy_",
 ]
 
-shift_ops = (
-    "lshift",
-    "rshift",
-    "ilshift",
-    "irshift",  # inplace ops
-)
-arithmetic_ops = (
+binary_ops = (
     "add",
     "sub",
     "mul",
     "div",
     "pow",
+    "lshift",
+    "rshift",
     "mod",
     "truediv",
     "matmul",
@@ -199,26 +195,24 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
     "rtruediv",
     "rfloordiv",
     "rpow",  # reverse arithmetic
-    "iadd",
-    "idiv",
-    "imul",
-    "isub",
-    "ifloordiv",
-    "imod",  # inplace ops
-)
-logic_ops = (
     "and",
     "or",
     "xor",
     "rand",
     "ror",
-    "rxor",  # reverse logic
+    "rxor",  # logic
+    "iadd",
     "iand",
+    "idiv",
+    "ilshift",
+    "imul",
     "ior",
-    "ixor",  # inplace ops
+    "irshift",
+    "isub",
+    "ixor",
+    "ifloordiv",
+    "imod",  # inplace ops
 )
-binary_ops = shift_ops + arithmetic_ops + logic_ops
-
 symmetric_comparison_ops = ("eq", "ne")
 asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
 comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
@@ -238,28 +232,14 @@ def sig_for_ops(opname: str) -> list[str]:
     assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"
 
     name = opname[2:-2]
-    if name == "rpow":
-        return [  # somehow required to make mypy ci happy?
-            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # type: ignore[has-type]"
-        ]
-    elif name in arithmetic_ops:
-        return [
-            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
-        ]
-    elif name in logic_ops:
-        return [f"def {opname}(self, other: Union[Tensor, _bool]) -> Tensor: ..."]
-    elif name in shift_ops:
-        return [f"def {opname}(self, other: Union[Tensor, _int]) -> Tensor: ..."]
-    elif name in symmetric_comparison_ops:
-        return [
+    if name in binary_ops:
+        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
+    elif name in comparison_ops:
+        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
+        if name in symmetric_comparison_ops:
             # unsafe override https://github.com/python/mypy/issues/5704
-            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # type: ignore[override]",
-            f"def {opname}(self, other: Any) -> _bool: ...",
-        ]
-    elif name in asymmetric_comparison_ops:
-        return [
-            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
-        ]
+            sig += "  # type: ignore[override]"
+        return [sig]
     elif name in unary_ops:
         return [f"def {opname}(self) -> Tensor: ..."]
     elif name in to_py_type_ops:

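For context on what this hunk changes in the generated stubs: after the revert, sig_for_ops again emits Any-typed signatures for the binary and comparison dunder ops, where pytorch#135392 had emitted Union[Tensor, Number, _complex] (with _bool and _int variants for logic and shift ops). A minimal, self-contained sketch of the restored behavior, using only the tuples and f-strings visible in the diff above (binary_ops is trimmed here for brevity):

# Sketch of the restored stub generation; not the full gen_pyi.py.
binary_ops = ("add", "sub", "mul")  # trimmed; see the full tuple above
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops


def sig_for_ops(opname: str) -> list[str]:
    name = opname[2:-2]  # "__add__" -> "add"
    if name in binary_ops:
        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
    elif name in comparison_ops:
        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
        if name in symmetric_comparison_ops:
            # Tensor.__eq__ returns Tensor, not bool, hence the unsafe override.
            sig += "  # type: ignore[override]"
        return [sig]
    raise ValueError(f"unhandled op: {opname}")


print(sig_for_ops("__add__"))  # ['def __add__(self, other: Any) -> Tensor: ...']
print(sig_for_ops("__eq__"))   # same shape, with '# type: ignore[override]' appended
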
Diff for: torch/_decomp/decompositions.py (+1 -2)

@@ -2291,8 +2291,7 @@ def native_batch_norm_backward(
     mean = save_mean_cast
     invstd = save_invstd_cast
     if train:
-        assert mean is not None and invstd is not None
-
+        assert save_mean_cast is not None and save_invstd_cast is not None
     else:
         assert running_mean_cast is not None and running_var_cast is not None
         mean = running_mean_cast

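The commit message says only that the tightened hints broke an internal build, so the rationale below is a hedged illustration, not something this commit states: asserting on the original Optional names (save_mean_cast, save_invstd_cast) lets a static checker narrow exactly the names the assert mentions. A toy, runnable example of that narrowing (not PyTorch code):

from typing import Optional


def pick_mean(save_mean_cast: Optional[float], running_mean_cast: Optional[float], train: bool) -> float:
    if train:
        # Restored form: assert on the original name, which narrows
        # save_mean_cast from Optional[float] to float for the checker.
        assert save_mean_cast is not None
        return save_mean_cast
    assert running_mean_cast is not None
    return running_mean_cast


print(pick_mean(0.5, None, train=True))   # 0.5
print(pick_mean(None, 0.1, train=False))  # 0.1
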
Diff for: torch/_inductor/fx_passes/efficient_conv_bn_eval.py (-1)

@@ -33,7 +33,6 @@ def efficient_conv_bn_eval(
     """
 
     assert bn.running_var is not None
-    assert bn.running_mean is not None
 
     # These lines of code are designed to deal with various cases
     # like bn without affine transform, and conv without bias

Diff for: torch/ao/quantization/_equalize.py (+7 -3)

@@ -128,7 +128,8 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             "module type not supported:", type(module1), " ", type(module2)
         )
 
-    bias = get_module_bias(module1) if has_bias(module1) else None
+    conv1_has_bias = has_bias(module1)
+    bias = None
 
     weight1 = get_module_weight(module1)
     weight2 = get_module_weight(module2)
@@ -139,6 +140,9 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             number input channels of second arg"
         )
 
+    if conv1_has_bias:
+        bias = get_module_bias(module1)
+
     weight1_range = channel_range(weight1, output_axis)
     weight2_range = channel_range(weight2, input_axis)
 
@@ -147,7 +151,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     scaling_factors = torch.sqrt(weight1_range / weight2_range)
     inverse_scaling_factors = torch.reciprocal(scaling_factors)
 
-    if bias is not None:
+    if conv1_has_bias:
         bias = bias * inverse_scaling_factors
 
     # formatting the scaling (1D) tensors to be applied on the given argument tensors
@@ -164,7 +168,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     weight2 = weight2 * scaling_factors
 
     set_module_weight(module1, weight1)
-    if bias is not None:
+    if conv1_has_bias:
         set_module_bias(module1, bias)
     set_module_weight(module2, weight2)

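The pattern restored across these four hunks: compute conv1_has_bias once up front and gate every later bias step on that flag, rather than re-testing bias is not None after bias has been reassigned. A standalone sketch of the same shape (rescale and factors are illustrative names, not the _equalize API):

from typing import Optional, Tuple

import torch


def rescale(weight: torch.Tensor, bias: Optional[torch.Tensor], factors: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    has_bias = bias is not None  # decided once, up front
    if has_bias:
        bias = bias * factors    # every later bias step keys off the flag
    weight = weight * factors.reshape(-1, 1)
    return weight, bias


w = torch.ones(2, 3)
f = torch.tensor([0.5, 2.0])
print(rescale(w, torch.ones(2), f)[1])  # tensor([0.5000, 2.0000])
print(rescale(w, None, f)[1])           # None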