Skip to content

Commit edfd224

Browse files
committed
check tensor attrs of tensor wrapper subclasses in prologue
also use `pytorch_executor` in the `transform_for_execution` of `prologue_trace` as it could have the prim of tensor subclass flattening whose definition is only available in pytorch executor. Signed-off-by: Masaki Kozuki <[email protected]>
1 parent d9ed305 commit edfd224

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

thunder/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def get_computation_and_inputs(*args, **kwargs):
592592

593593
prologue_traces += transform_for_execution(
594594
prologue_trc,
595-
executors_list=(pythonex,),
595+
executors_list=(pythonex, pytorch_executor),
596596
use_del_last_used=False,
597597
)
598598
prologue_trc = prologue_traces[-1]

thunder/clang/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@
2020
from thunder.core import utils
2121
import thunder.core.prims as prims
2222
from thunder.core.proxies import (
23+
AnyProxy,
2324
IntegerProxy,
24-
NumberProxy,
2525
NumberLike,
26+
NumberProxy,
27+
Proxy,
28+
SubclassTensorProxy,
2629
TensorProxy,
27-
pyval,
28-
pytype,
2930
proxy,
30-
AnyProxy,
31-
Proxy,
31+
pytype,
32+
pyval,
3233
)
3334
import thunder.core.devices as devices
3435

@@ -67,7 +68,7 @@ def __call__(self, fn: Callable) -> Callable:
6768

6869
# Checks a tensor's shape and metadata (for use with cache check)
6970
@clangop()
70-
def check_tensor_shape_and_metadata(t: TensorProxy, /) -> None:
71+
def check_tensor_shape_and_metadata(t: TensorProxy | SubclassTensorProxy, /) -> None:
7172
return prims.check_tensor_shape_and_metadata(
7273
t,
7374
# replace Proxy entries with `-1`s as wild card, as we any value is

thunder/core/jit_ext.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,7 @@ class JITSharpEdgeError(RuntimeError):
146146
def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
147147
sharp_edges: SHARP_EDGES_OPTIONS = get_jit_ctx().sharp_edges
148148

149-
s: str = (
150-
f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
151-
)
149+
s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
152150

153151
if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
154152
return do_raise(JITSharpEdgeError(s))
@@ -1719,9 +1717,12 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:
17191717

17201718
with tracectx(prologue_trace):
17211719
for prim, *args in ctx._constraints:
1720+
subclass_tensor: SubclassTensorProxy | None = None
17221721
for a in args:
17231722
if isinstance(a, Proxy):
17241723
unpack(a)
1724+
if isinstance(a, SubclassTensorProxy):
1725+
subclass_tensor = a
17251726
# unpacking Proxy in TensorProxy.shape which is used in `check_tensor_shape_and_metadata`
17261727
if prim == clang.check_tensor_shape_and_metadata:
17271728
for s in a.shape:
@@ -1730,6 +1731,13 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:
17301731

17311732
prim(*args)
17321733

1734+
if isinstance(subclass_tensor, SubclassTensorProxy):
1735+
for t in prims.flatten_tensor_subclass(subclass_tensor):
1736+
for s in t.shape:
1737+
if isinstance(s, Proxy):
1738+
unpack(s)
1739+
prim(t)
1740+
17331741
cache_info = thunder._get_cache_info()
17341742
# assert len of cache info to ensure that we're not missing anything?
17351743
if cache_info:

0 commit comments

Comments
 (0)