
Commit a534003

trace transform of tensor wrapper subclass

to support `__torch_dispatch__`. Since it extends behavior that is implemented at the C++ level, we need to apply the transform to the split forward and backward traces separately.

Signed-off-by: Masaki Kozuki <[email protected]>
1 parent 55cc826 commit a534003
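
For context, a minimal sketch of the protocol this commit traces: a hypothetical `ScaledTensor` wrapper subclass (not part of this diff; all names illustrative) implementing `__torch_dispatch__` plus the `__tensor_flatten__`/`__tensor_unflatten__` pair that the prims and proxies below build on.

import torch
from torch.utils._pytree import tree_map

class ScaledTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: float):
        # Wrapper subclass: the outer tensor mirrors the inner tensor's metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device, requires_grad=data.requires_grad
        )

    def __init__(self, data: torch.Tensor, scale: float):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # Inner-tensor attribute names plus const metadata.
        return ["_data"], {"_scale": self._scale}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return ScaledTensor(inner_tensors["_data"], metadata["_scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda t: t._data if isinstance(t, ScaledTensor) else t
        # Run the op on the inner tensor; a real subclass would rewrap the result.
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))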

File tree

9 files changed: +1056 -8 lines changed


docs/source/reference/transforms/index.rst

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ thunder.transforms
 
    MaterializationTransform
    ConstantFolding
+   unroll_tensor_subclasses

thunder/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -73,6 +73,7 @@
 from thunder.core.interpreter import print_interpreter_log, print_to_log
 from thunder.core.jit_ext import thunder_general_jit
 from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
+from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses
 
 # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
 import torch as pytorch
@@ -369,7 +370,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
     data_ptr_to_tensor_group_index = {}
     tensor_group_index_to_tensor_indices = defaultdict(list)
     for idx, t in enumerate(flat_args):
-        if pytorch.is_tensor(t) and t.layout == pytorch.strided:
+        if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
             data_ptr = t.untyped_storage().data_ptr()
             if data_ptr not in data_ptr_to_tensor_group_index:
                 data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
@@ -616,6 +617,7 @@ def get_computation_and_inputs(*args, **kwargs):
         computation_trc = dce(computation_trc)
         computation_traces.append(computation_trc)
 
+        _unroll_tensor_subclasses_applied = False
         backward_trc = None
         if not cd.disable_torch_autograd_support:
             tensor_cls = (pytorch.Tensor, TensorProxy)
@@ -626,10 +628,15 @@
             # transform_for_execution and various sorting of symbols,
             # applying transform_for_execution after this would be
             # breaking the order of operations
+            _unroll_tensor_subclasses_applied = True
             computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
             # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
             # by split_forward_backward
 
+        if not _unroll_tensor_subclasses_applied:
+            computation_trc = unroll_tensor_subclasses(computation_trc)
+            computation_traces.append(computation_trc)
+
         if backward_trc is None:
             from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
             from thunder.executors.passes import _transform_for_operator_executor_execution
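
With the fallback above in place, a hedged usage sketch reusing the hypothetical `ScaledTensor` from the commit-message section; whether a particular subclass traces end to end depends on the ops it reaches.

import torch
import thunder

def f(x):
    return x + 1

jf = thunder.jit(f)
# With torch-autograd support enabled, unroll_tensor_subclasses runs inside
# split_forward_backward; otherwise the new fallback applies it to the
# computation trace directly.
out = jf(ScaledTensor(torch.randn(2, 2), 2.0))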

thunder/core/prims.py

Lines changed: 165 additions & 3 deletions
@@ -280,6 +280,8 @@ class PrimIDs(Enum):
     SINK = auto()
     # Tensor Subclasses methods
     TENSOR_SUBCLASS_CTOR = auto()
+    FLATTEN_TENSOR_SUBCLASS = auto()
+    UNFLATTEN_TENSOR_SUBCLASS = auto()
 
 
 class OpTags(Enum):
@@ -4098,7 +4100,7 @@ def check_types(coll):
     return tuple(types_set)
 
 
-def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]:
+def filter_types_for_tensor_wrapper_subclass(types: tuple[Any, ...]) -> tuple[Any, ...]:
     return tuple(
         filter(
             lambda t: (
@@ -4170,7 +4172,7 @@ def printer_of_tensor_subclass_ctor(
     filtered_types = (cls,)
     if non_tensors:
         types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors])
-        filtered_types += filter_types(types)
+        filtered_types += filter_types_for_tensor_wrapper_subclass(types)
     new_imports = {t.__name__: t for t in filtered_types}
     bsym._import_ctx.update(new_imports)
     return s
@@ -4183,7 +4185,7 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
     filtered_types: tuple[Any, ...] = (cls,)
     if non_tensors:
         types = get_nested_types(non_tensors)
-        filtered_types += filter_types(types)
+        filtered_types += filter_types_for_tensor_wrapper_subclass(types)
     new_imports = {t.__name__: t for t in filtered_types}
     bsym._import_ctx.update(new_imports)
 
@@ -4195,3 +4197,163 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
     python_printer=printer_of_tensor_subclass_ctor,
     _bind_postprocess=bind_postprocess_of_tensor_subclass_ctor,
 )
+
+
+def printer_of_tensor_subclass_flatten(
+    bsym: BoundSymbol,
+    out_printables: Any,
+    arg_printables: Sequence[Printable],
+    kwarg_printables: dict[str, Printable],
+) -> str | Iterable[str]:
+    from itertools import chain
+
+    arg_str = (
+        ""
+        if (arg_printables is None or len(arg_printables) == 0)
+        else ", ".join(codeutils.prettyprint(x) for x in arg_printables)
+    )
+
+    result_str: str
+    if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
+        result_str = ""
+    else:
+        result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "
+
+    # Creates a comment describing the output
+    comment_str = ""
+    if isinstance(bsym.output, Proxy):
+        comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"
+
+    s = f"{result_str}{arg_str}.__tensor_flatten__(){comment_str}"
+
+    if bsym.header:
+        header_lines = (
+            bsym.header
+            if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
+            else bsym.header.splitlines()
+        )
+        header_lines = (f"# {line}" for line in header_lines)
+        return chain(header_lines, [s])
+
+    return s
+
+
+# NOTE(crcrpar): The behavior is different from PyTorch `subclass_tensor.__tensor_flatten__()`
+# that returns a list of tensor attr names and a dict of const metadata. In Thunder traces,
+# const values could be obviated and actual tensor proxies would be more useful
+# than tensor attr names.
+def flatten_tensor_subclass_meta(t: SubclassTensorProxy) -> tuple[TensorProxy, ...]:
+    tensor_attr_names, metadata = t.__tensor_flatten__()
+    tensors = tuple(getattr(t, name) for name in tensor_attr_names)
+    return tensors
+
+
+flatten_tensor_subclass = make_prim(
+    PrimIDs.FLATTEN_TENSOR_SUBCLASS,
+    "flatten_tensor_subclass",
+    meta=flatten_tensor_subclass_meta,
+    python_printer=printer_of_tensor_subclass_flatten,
+)
+
+
+def printer_of_unflatten_tensor_subclass(
+    bsym: BoundSymbol,
+    out_printables: Any,
+    arg_printables: Sequence[Printable],
+    kwarg_printables: dict[str, Printable],
+) -> str | Iterable[str]:
+    from itertools import chain
+
+    wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0]
+    if isinstance(wrapped_cls, torch._C._TensorMeta):
+        cls = wrapped_cls
+    else:
+        cls: torch._C._TensorMeta = wrapped_cls.obj
+
+    arg_str = (
+        ""
+        if (arg_printables is None or len(arg_printables) == 0)
+        else ", ".join(codeutils.prettyprint(x) for x in arg_printables[1:])
+    )
+    kwarg_str: str
+
+    if len(kwarg_printables) == 0:
+        kwarg_str = ""
+    else:
+        kwarg_str = ", ".join(f"{k}={codeutils.prettyprint(v)}" for k, v in kwarg_printables.items())
+
+    result_str: str
+    if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
+        result_str = ""
+    else:
+        result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "
+
+    # Creates a comment describing the output
+    comment_str = ""
+    if isinstance(bsym.output, Proxy):
+        comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"
+
+    s = f"{result_str}{cls.__name__}.__tensor_unflatten__({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}"
+
+    if bsym.header:
+        header_lines = (
+            bsym.header
+            if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
+            else bsym.header.splitlines()
+        )
+        header_lines = (f"# {line}" for line in header_lines)
+        return chain(header_lines, [s])
+
+    return s
+
+
+def bind_postprocess_of_unflatten_tensor_subclass(bsym: BoundSymbol) -> None:
+    cls = bsym.args[0]
+    inner_tensors = bsym.args[1]
+    metadata = bsym.args[2]
+
+    filtered_types: tuple[Any, ...] = (cls,)
+    if metadata:
+        types = get_nested_types(list(metadata.values()))
+        filtered_types += filter_types_for_tensor_wrapper_subclass(types)
+    new_imports = {t.__name__: t for t in filtered_types}
+    bsym._import_ctx.update(new_imports)
+
+
+def unflatten_tensor_subclass_meta(
+    tensor_subclass_type,
+    inner_tensors: dict[str, TensorProxy],
+    metadata: dict[str, Any],
+) -> SubclassTensorProxy:
+    first_tensor: TensorProxy = list(inner_tensors.values())[0]
+    a = SubclassTensorProxy(
+        shape=first_tensor.shape,
+        device=first_tensor.device,
+        dtype=first_tensor.dtype,
+        requires_grad=first_tensor.requires_grad,
+        tensors=list(inner_tensors.values()),
+        non_tensors=list(metadata.values()),
+        subclass_type=tensor_subclass_type,
+    )
+    for name, value in inner_tensors.items():
+        setattr(a, name, value)
+    for name, value in metadata.items():
+        setattr(a, name, value)
+    return a
+
+
+def unflatten_tensor_subclass_python_impl(
+    tensor_subclass_type,
+    inner_tensors: dict[str, TensorProxy],
+    metadata: dict[str, Any],
+) -> torch.Tensor:
+    return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)
+
+
+unflatten_tensor_subclass = make_prim(
+    PrimIDs.UNFLATTEN_TENSOR_SUBCLASS,
+    "unflatten_tensor_subclass",
+    meta=unflatten_tensor_subclass_meta,
+    python_printer=printer_of_unflatten_tensor_subclass,
+    _bind_postprocess=bind_postprocess_of_unflatten_tensor_subclass,
+)
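
To make the two prims concrete, here is a hedged round trip through the underlying hooks using the hypothetical `ScaledTensor` sketched under the commit message. Note `flatten_tensor_subclass_meta` returns the inner tensor proxies themselves rather than attribute names, and `unflatten_tensor_subclass_python_impl` forwards `-1, -1` for the outer size and stride.

import torch

st = ScaledTensor(torch.randn(2, 2), 2.0)

# What the flatten prim gathers: names/metadata from the hook, then the
# actual inner tensors looked up by attribute name.
tensor_attr_names, metadata = st.__tensor_flatten__()  # ["_data"], {"_scale": 2.0}
tensors = tuple(getattr(st, name) for name in tensor_attr_names)

# What the unflatten prim rebuilds, mirroring the -1, -1 placeholders above.
inner_tensors = dict(zip(tensor_attr_names, tensors))
st2 = ScaledTensor.__tensor_unflatten__(inner_tensors, metadata, -1, -1)
assert isinstance(st2, ScaledTensor) and st2._scale == st._scale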

thunder/core/proxies.py

Lines changed: 36 additions & 2 deletions
@@ -2111,6 +2111,7 @@ def __setattr__(self, name, value):
 
 # TODO: move this function to jit_ext.py
 def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
+    from torch._subclasses.fake_tensor import FakeTensor
     from thunder.core.interpreter import ProvenanceRecord, PseudoInst, wrap_const
 
     if hasattr(t, "_thunder_device"):
@@ -2145,8 +2146,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
     else:
         # NOTE Without tuple(t.shape) then the shape would be a torch.Size object
         shape = tuple(t.shape)
-    return TensorProxy(
-        name,
+    ctor_kwargs = dict(
+        name=name,
         shape=tuple(shape),
         device=device,
         dtype=dtype,
@@ -2156,6 +2157,39 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
         history=history,
         thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
     )
+    # n.b.(crcrpar): :class:`thunder.dynamo.ThunderCompiler.__call__` takes torch.fx GraphModule
+    # where `FakeTensor` seems to be used, leading to failures observed in e.g.
+    # https://github.com/Lightning-AI/lightning-thunder/actions/runs/11689709564/job/32553053319#step:10:5747
+    # https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=219328&view=logs&jobId=5b0799f7-725e-5b16-9b83-c0a5a25d03f0&j=5b0799f7-725e-5b16-9b83-c0a5a25d03f0
+    if (
+        isinstance(t, torch.Tensor)
+        and type(t) not in (torch.Tensor, torch.nn.Parameter, FakeTensor)
+        and hasattr(t, "__tensor_flatten__")
+        and hasattr(t, "__tensor_unflatten__")
+    ):
+        baseutils.check(
+            hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"),
+            lambda: f"{t=} seems to be a tensor subclass but not traceable",
+        )
+        tensor_attr_names, metadata = t.__tensor_flatten__()
+        tensors = [tensorproxy(getattr(t, name), name=None, history=history) for name in tensor_attr_names]
+        ctor_kwargs.update(
+            {
+                "tensors": tensors,
+                "non_tensors": list(metadata.values()),
+                "subclass_type": type(t),
+            }
+        )
+        p = SubclassTensorProxy(**ctor_kwargs)
+        p._tensor_attr_names = tensor_attr_names
+        p._non_tensor_attr_names = list(metadata.keys())
+        for name, tensor in zip(tensor_attr_names, tensors):
+            setattr(p, name, tensor)
+        for name, value in metadata.items():
+            setattr(p, name, value)
+        return p
+    else:
+        return TensorProxy(**ctor_kwargs)
 
 
 def futuretensorproxy(
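
The gating condition in the new branch mirrors PyTorch's notion of a traceable wrapper subclass (cf. torch.utils._python_dispatch.is_traceable_wrapper_subclass); as a standalone sketch of the same check:

import torch
from torch._subclasses.fake_tensor import FakeTensor

def _is_traceable_wrapper_subclass(t) -> bool:
    # Strict subclass (and not a FakeTensor) exposing both flatten hooks --
    # exactly the condition that selects the SubclassTensorProxy branch above.
    return (
        isinstance(t, torch.Tensor)
        and type(t) not in (torch.Tensor, torch.nn.Parameter, FakeTensor)
        and hasattr(t, "__tensor_flatten__")
        and hasattr(t, "__tensor_unflatten__")
    )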

thunder/executors/torch_autograd.py

Lines changed: 7 additions & 0 deletions
@@ -132,6 +132,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     from thunder.distributed.transforms import FSDPCommBucketing
     from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops
     from thunder.executors.passes import del_last_used, transform_for_execution
+    from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses
 
     utils.check(compile_data is not None, lambda: "`compile_data` is required")
     # NOTE: This function is rather slow, so it's intended to be used
@@ -158,6 +159,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     fw_traces = [fw_trace]
     bw_traces = [bw_trace]
 
+    fw_trace = unroll_tensor_subclasses(fw_trace)
+    fw_traces.append(fw_trace)
+
     from thunder.distributed import FSDPType
 
     # only enable rematerialize_params_in_backward when using FSDP ZeRO3
@@ -262,6 +266,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     if getattr(compile_data.fn, "use_fsdp", False):
         bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace)
 
+    bw_trace = unroll_tensor_subclasses(bw_trace)
+    bw_traces.append(bw_trace)
+
     # Now we can run the optimization passes on the backward trace
     # TODO Restore request for no rematerialization
     bw_extrace = transform_for_execution(

thunder/executors/torchex.py

Lines changed: 46 additions & 2 deletions
@@ -2235,13 +2235,13 @@ def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensor
 
 
 def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
-    from thunder.core.prims import get_nested_types, filter_types
+    from thunder.core.prims import get_nested_types, filter_types_for_tensor_wrapper_subclass
 
     cls, _name, _shape, _device, _dtype, _requires_grad, _tensors, non_tensors = bsym.args
     filtered_types = (cls,)
     if non_tensors:
         types = get_nested_types(non_tensors)
-        filtered_types += filter_types(types)
+        filtered_types += filter_types_for_tensor_wrapper_subclass(types)
     new_imports = {t.__name__: t for t in filtered_types}
     bsym._import_ctx.update(new_imports)
 
@@ -2254,3 +2254,47 @@ def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
     python_printer=prims.printer_of_tensor_subclass_ctor,
 )
 _register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable)
+
+
+def flatten_tensor_subclass_impl(t):
+    tensor_attr_names, metadata = t.__tensor_flatten__()
+    tensors = tuple(getattr(t, name) for name in tensor_attr_names)
+    return tensors
+
+
+flatten_tensor_subclass = ex.register_operator(
+    "flatten_tensor_subclass",
+    meta=prims.flatten_tensor_subclass.meta,
+    fn=flatten_tensor_subclass_impl,
+)
+_register_implementation(
+    prims.flatten_tensor_subclass,
+    flatten_tensor_subclass,
+    checker=_always_executable,
+)
+
+
+def unflatten_tensor_subclass_impl(
+    tensor_subclass_type: torch._C._TensorMeta,
+    inner_tensors: dict[str, TensorLike],
+    metadata: dict,
+):
+    for key in metadata:
+        v = metadata[key]
+        if isinstance(v, dtypes.dtype):
+            metadata[key] = to_torch_dtype(v)
+        elif isinstance(v, devices.Device):
+            metadata[key] = to_torch_device(v)
+    return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)
+
+
+unflatten_tensor_subclass = ex.register_operator(
+    "unflatten_tensor_subclass",
+    meta=prims.unflatten_tensor_subclass.meta,
+    fn=unflatten_tensor_subclass_impl,
+)
+_register_implementation(
+    prims.unflatten_tensor_subclass,
+    unflatten_tensor_subclass,
+    checker=_always_executable,
+)
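
A small illustration of the metadata normalization `unflatten_tensor_subclass_impl` performs before delegating to the subclass hook; the helper import paths are assumptions based on how this file references `dtypes`, `devices`, `to_torch_dtype`, and `to_torch_device`.

import torch
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.dtypes import to_torch_dtype
from thunder.core.devices import to_torch_device

metadata = {"dtype": dtypes.float32, "device": devices.Device("cpu"), "_scale": 2.0}
for key, v in metadata.items():
    # Thunder dtype/device objects that leaked into const metadata are mapped
    # back to their torch counterparts; everything else passes through.
    if isinstance(v, dtypes.dtype):
        metadata[key] = to_torch_dtype(v)
    elif isinstance(v, devices.Device):
        metadata[key] = to_torch_device(v)

assert metadata["dtype"] is torch.float32
assert metadata["device"] == torch.device("cpu")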
