Skip to content

Commit a678eaf

Browse files
pianpwkpytorchmergebot
authored andcommitted
check fake/real mismatches during real tensor prop (pytorch#137747)
Summary: While testing exportability for PT2 Inference models, we found various cases of invalid op inputs during tracing, for example errors like: `a and b must have same reduction dim`, `expected scalar type Long but found Int`, etc. Looking more closely, these happened to due the same few meta kernels & eager kernels producing mismatched outputs upstream (e.g. different output tensor dtype, int output). Adding checks to catch mismatched outputs in real tensor prop upstream, so errors are raised at the mismatched op, instead of the downstream ops taking them as inputs. Relies a lot on utils from [CrossRefFakeMode](https://github.com/pytorch/pytorch/blob/929797dedbf23376123ce95230c01a7e3b71e130/torch/_subclasses/fake_utils.py#L78) Follow ups: could add more checks, and maybe have a flag to only enable these for cases like draft mode, so perf doesn't suffer? Test Plan: test_export, test_fake_tensor Differential Revision: D64210055 Pull Request resolved: pytorch#137747 Approved by: https://github.com/zou3519
1 parent 9919932 commit a678eaf

File tree

7 files changed

+313
-53
lines changed

7 files changed

+313
-53
lines changed

test/export/test_export.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,120 @@ def forward(self, x):
10781078
ep_model = export(model, (x,), strict=False).module()
10791079
self.assertTrue(torch.allclose(model(x), ep_model(x)))
10801080

1081+
def test_real_tensor_size_mismatch(self):
1082+
from torch._subclasses.fake_tensor import MetadataMismatchError
1083+
1084+
class M(torch.nn.Module):
1085+
def forward(self, a, b):
1086+
return torch.ops.mylib.foo(a, b)
1087+
1088+
@torch.library.custom_op("mylib::foo", mutates_args={})
1089+
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
1090+
return a + b
1091+
1092+
@foo.register_fake
1093+
def foo_fake_impl(a, b):
1094+
m, n = a.shape
1095+
return torch.empty(n, m) # incorrectly permute
1096+
1097+
error_type = (
1098+
MetadataMismatchError
1099+
if is_non_strict_test(self._testMethodName)
1100+
else torch._dynamo.exc.TorchRuntimeError
1101+
)
1102+
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
1103+
# won't catch anything if dims are equal
1104+
export(
1105+
M(),
1106+
(torch.randn(4, 4), torch.randn(4, 4)),
1107+
)
1108+
# catch concrete inequality
1109+
with self.assertRaisesRegex(
1110+
error_type,
1111+
"Real tensor propagation found an output size mismatch between fake shape 8 and real shape 4, "
1112+
"at output index 0, dimension 0 for func: mylib.foo.default",
1113+
):
1114+
export(
1115+
M(),
1116+
(torch.randn(4, 8), torch.randn(4, 8)),
1117+
)
1118+
# same test with dynamic shapes
1119+
d0 = Dim("d0")
1120+
d1 = Dim("d1")
1121+
export(
1122+
M(),
1123+
(torch.randn(4, 4), torch.randn(4, 4)),
1124+
dynamic_shapes={
1125+
"a": (d0, d1),
1126+
"b": (d0, d1),
1127+
},
1128+
)
1129+
with self.assertRaisesRegex(
1130+
error_type,
1131+
"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
1132+
"at output index 0, dimension 0 for func: mylib.foo.default",
1133+
):
1134+
export(
1135+
M(),
1136+
(torch.randn(4, 8), torch.randn(4, 8)),
1137+
dynamic_shapes={
1138+
"a": (d0, d1),
1139+
"b": (d0, d1),
1140+
},
1141+
)
1142+
1143+
def test_real_tensor_alias_dtype_mismatch(self):
1144+
from torch._subclasses.fake_tensor import MetadataMismatchError
1145+
1146+
error_type = (
1147+
MetadataMismatchError
1148+
if is_non_strict_test(self._testMethodName)
1149+
else torch._dynamo.exc.TorchRuntimeError
1150+
)
1151+
1152+
# test alias case
1153+
class M(torch.nn.Module):
1154+
def forward(self, a):
1155+
return torch.ops.mylib.foo_alias(a)
1156+
1157+
@torch.library.custom_op("mylib::foo_alias", mutates_args={})
1158+
def foo_alias(a: torch.Tensor) -> torch.Tensor:
1159+
return a * 2
1160+
1161+
@foo_alias.register_fake
1162+
def foo_fake_impl(a):
1163+
return a
1164+
1165+
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
1166+
with self.assertRaisesRegex(
1167+
error_type,
1168+
r"Real tensor propagation found an aliasing mismatch between fake output (.*\n)*.* "
1169+
r"and real output (.*\n)*.* for func: mylib.foo_alias.default",
1170+
):
1171+
ep = export(M(), (torch.randn(4, 4),))
1172+
1173+
# test dtype case
1174+
class N(torch.nn.Module):
1175+
def forward(self, a):
1176+
return torch.ops.mylib.foo_dtype(a)
1177+
1178+
@torch.library.custom_op("mylib::foo_dtype", mutates_args={})
1179+
def foo_dtype(a: torch.Tensor) -> torch.Tensor:
1180+
return a * 2
1181+
1182+
@foo_dtype.register_fake
1183+
def foo_fake_impl(a):
1184+
m, n = a.shape
1185+
return torch.empty([m, n], dtype=torch.int32)
1186+
1187+
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
1188+
with self.assertRaisesRegex(
1189+
error_type,
1190+
r"Real tensor propagation found a metadata mismatch between fake tensor (.*\n)*.* "
1191+
r"and real tensor (.*\n)*.* at output index 0, for func: mylib.foo_dtype.default",
1192+
):
1193+
ep = export(N(), (torch.randn(4, 4),))
1194+
10811195
def test_real_tensor_for_max_op(self):
10821196
class Foo(torch.nn.Module):
10831197
def forward(self, x, y):

test/test_fake_tensor.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
_CacheKeyState,
2929
DynamicOutputShapeException,
3030
extract_tensor_metadata,
31+
MetadataMismatchError,
3132
FakeTensor,
3233
FakeTensorConverter,
3334
FakeTensorMode,
@@ -1377,14 +1378,20 @@ def forward(self, arg1, arg2, arg3):
13771378
try:
13781379
with torch._subclasses.CrossRefFakeMode():
13791380
Repro()(*args)
1380-
except RuntimeError as e:
1381+
except MetadataMismatchError as e:
13811382
# We expect the cross ref to succed for the first output to fail
13821383
# for the rng state, see Note [Seed and Offset]
13831384
self.assertTrue("output[0]" not in str(e))
1384-
self.assertTrue(
1385-
"found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!"
1386-
in str(e)
1387-
)
1385+
if self.__class__.__name__.startswith("PropagateRealTensors"):
1386+
self.assertTrue(
1387+
"Real tensor propagation found a metadata mismatch"
1388+
in str(e)
1389+
)
1390+
else:
1391+
self.assertTrue(
1392+
"found mismatched tensor metadata for output"
1393+
in str(e)
1394+
)
13881395

13891396
# IMPORTANT!!! Always run even if CUDA is not available
13901397
def test_fake_gpu_no_init(self):

torch/_meta_registrations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,6 +2131,12 @@ def _compute_reduction_shape(self, dims, keepdim):
21312131
def device_hint(tensor) -> "str":
21322132
if isinstance(tensor, torch._subclasses.FakeTensor):
21332133
return tensor.fake_device.type
2134+
elif (
2135+
hasattr(tensor, "device")
2136+
and hasattr(tensor.device, "type")
2137+
and tensor.device.type != "meta"
2138+
):
2139+
return tensor.device.type
21342140
else:
21352141
return "cuda" # default to cuda
21362142

torch/_prims_common/__init__.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def _maybe_get_pytype(t):
136136
def compare_tensor_meta(
137137
a: TensorLikeType,
138138
b: TensorLikeType,
139+
check_sizes=True,
139140
check_strides=False,
140141
*,
141142
allow_rhs_unbacked=False,
@@ -148,16 +149,20 @@ def compare_tensor_meta(
148149
In the future this will validate additional metadata, like
149150
strides.
150151
"""
152+
from torch._subclasses.fake_tensor import MetadataMismatchError
153+
151154
assert isinstance(a, TensorLike)
152155
assert isinstance(b, TensorLike)
153156

154-
if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked):
157+
if check_sizes and not same_shape(
158+
a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked
159+
):
155160
msg = f"Shapes {a.shape} and {b.shape} are not equal!"
156-
raise AssertionError(msg)
161+
raise MetadataMismatchError(msg)
157162

158163
if a.dtype != b.dtype:
159164
msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!"
160-
raise AssertionError(msg)
165+
raise MetadataMismatchError(msg)
161166

162167
if a.device != b.device:
163168
# Handles special cuda:0 vs cuda case
@@ -168,27 +173,27 @@ def compare_tensor_meta(
168173
pass
169174
else:
170175
msg = f"Devices {a.device} and {b.device} are not equal!"
171-
raise AssertionError(msg)
176+
raise MetadataMismatchError(msg)
172177

173178
# Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050
174179
if check_strides:
175180
same_strides, idx = check_significant_strides(a, b)
176181
if not same_strides:
177182
msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!"
178-
raise RuntimeError(msg)
183+
raise MetadataMismatchError(msg)
179184

180185
if a.storage_offset() != b.storage_offset():
181186
msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!"
182-
raise RuntimeError(msg)
187+
raise MetadataMismatchError(msg)
183188

184189
if check_conj:
185190
if a.is_conj() != b.is_conj():
186-
raise RuntimeError(
191+
raise MetadataMismatchError(
187192
f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}"
188193
)
189194

190195
if a.is_neg() != b.is_neg():
191-
raise RuntimeError(
196+
raise MetadataMismatchError(
192197
f"Neg mismatch! is_neg is set to {a.is_neg()} and {b.is_neg()}"
193198
)
194199

torch/_subclasses/fake_tensor.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ class UnsupportedOperatorException(RuntimeError):
140140
func: OpOverload
141141

142142

143+
@dataclass
144+
class MetadataMismatchError(RuntimeError):
145+
reason: str
146+
147+
143148
def ordered_set(*items: T) -> Dict[T, Literal[True]]:
144149
return dict.fromkeys(items, True)
145150

@@ -2031,6 +2036,11 @@ def maybe_to_real_tensor(
20312036
def maybe_propagate_real_tensors(fake_out: T) -> T:
20322037
import sympy
20332038

2039+
from torch._subclasses.fake_utils import (
2040+
_check_alias_info,
2041+
_check_fake_real_tensors,
2042+
)
2043+
20342044
log.debug("maybe_propagate_real_tensors %s", func)
20352045

20362046
def go(t: object, real_t: Tensor) -> None:
@@ -2057,6 +2067,33 @@ def go(t: object, real_t: Tensor) -> None:
20572067
assert self.shape_env is not None
20582068
self.shape_env.set_unbacked_var_to_val(s, int(real_t))
20592069

2070+
def _check_fake_real_vals(fake: Any, real: Any) -> None:
2071+
# use real values + ShapeEnv to check mismatches between potentially symbolic values
2072+
if isinstance(fake, (SymInt, SymFloat)):
2073+
# symbolic expression, ask ShapeEnv to substitute known backed/unbacked values
2074+
assert self.shape_env is not None
2075+
if (
2076+
not fake.node.expr.free_symbols
2077+
- self.shape_env.var_to_val.keys()
2078+
- self.shape_env.unbacked_var_to_val.keys()
2079+
):
2080+
if (
2081+
self.shape_env._maybe_evaluate_static(
2082+
sympy.Eq(fake.node.expr, real), compute_hint=True
2083+
)
2084+
is not sympy.S.true
2085+
):
2086+
raise MetadataMismatchError(
2087+
f"mismatch between fake value {fake} and real value {real} "
2088+
)
2089+
elif isinstance(
2090+
fake, (int, float, bool)
2091+
): # concrete value, check direct equality
2092+
if fake != real:
2093+
raise MetadataMismatchError(
2094+
f"mismatch between fake value {fake} and real value {real} "
2095+
)
2096+
20602097
if real_out is not nil:
20612098
if (
20622099
not isinstance(fake_out, Tensor)
@@ -2073,6 +2110,65 @@ def go(t: object, real_t: Tensor) -> None:
20732110
else:
20742111
tree_map_(go, fake_out, real_out)
20752112

2113+
# check fake/real alias info
2114+
try:
2115+
_check_alias_info(
2116+
"Real tensor propagation found",
2117+
real_out,
2118+
(real_args, real_kwargs),
2119+
fake_out,
2120+
(args, kwargs),
2121+
)
2122+
except MetadataMismatchError as exc:
2123+
raise MetadataMismatchError(
2124+
f"Real tensor propagation found an aliasing mismatch between "
2125+
f"fake output {fake_out} and real output {real_out}, "
2126+
f" for func: {func}"
2127+
) from exc
2128+
2129+
# check fake/real tensor properies, sizes & output values
2130+
for i, (_real_out, _fake_out) in enumerate(
2131+
zip(pytree.tree_leaves(real_out), pytree.tree_leaves(fake_out))
2132+
):
2133+
if isinstance(_fake_out, torch.Tensor):
2134+
try:
2135+
_check_fake_real_tensors(
2136+
_fake_out,
2137+
_real_out,
2138+
context="Real tensor propagation found",
2139+
sizes=False, # manual check below
2140+
strides=False, # skip strides
2141+
storage_offset=True,
2142+
requires_grad=False, # issues with FakeTensorConverter preserving requires_grad
2143+
)
2144+
except MetadataMismatchError as exc:
2145+
raise MetadataMismatchError(
2146+
f"Real tensor propagation found a metadata mismatch between "
2147+
f"fake tensor {_fake_out} and real tensor {_real_out}, "
2148+
f" at output index {i}, for func: {func}"
2149+
) from exc
2150+
2151+
for j, (s_fake, s_real) in enumerate(
2152+
zip(_fake_out.size(), _real_out.size())
2153+
):
2154+
try:
2155+
_check_fake_real_vals(s_fake, s_real)
2156+
except MetadataMismatchError as exc:
2157+
raise MetadataMismatchError(
2158+
f"Real tensor propagation found an output size mismatch between "
2159+
f"fake shape {s_fake} and real shape {s_real}, at output "
2160+
f"index {i}, dimension {j} for func: {func}"
2161+
) from exc
2162+
else:
2163+
try:
2164+
_check_fake_real_vals(_fake_out, _real_out)
2165+
except MetadataMismatchError as exc:
2166+
raise MetadataMismatchError(
2167+
f"Real tensor propagation found an output value mismatch between "
2168+
f"fake output value {_fake_out} and real output value {_real_out}, "
2169+
f" at output index {i}, for func: {func}"
2170+
) from exc
2171+
20762172
# If a data-dependent op is used in a decomposition, we
20772173
# may need to get the unbacked settings "early"
20782174
# TODO: Is this really needed?

0 commit comments

Comments
 (0)