Commit 6e6b077

updates for MLP with torchao.float8
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent a534003 commit 6e6b077


9 files changed: 257 additions, 5 deletions

thunder/core/jit_ext.py

Lines changed: 22 additions & 0 deletions

@@ -667,6 +667,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
     So far, non-tensor ``ctx`` attributes seem to be folded into a trace.
     """
     from thunder.core.baseutils import check, sequencify
+    from thunder.core.transforms import dce

     custom_autograd_function_cls = unwrap(obj)
     custom_forward = custom_autograd_function_cls.forward
@@ -678,6 +679,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
     )
     if trace_of_fwd is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
         return trace_of_fwd
+    trace_of_fwd = dce(trace_of_fwd)

     # Forward.
     unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
@@ -691,6 +693,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
         for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args)
     ]
     trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols
+    trace_of_fwd = dce(trace_of_fwd)

     @wraps(trace_of_fwd.python_callable())
     def core_of_forward(*args, **kwargs):
@@ -737,6 +740,7 @@ def core_of_forward(*args, **kwargs):
         for a in filter(lambda a: isinstance(a, Proxy), trace_of_backward.args)
     ]
     trace_of_backward.bound_symbols = bwd_unpack_bsyms + trace_of_backward.bound_symbols
+    trace_of_backward = dce(trace_of_backward)

     bwd_trace_impl = TraceCtx()
     bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols)
@@ -770,6 +774,24 @@ def grad_transform(*args, **kwargs):
         execution_transform=core_of_forward,
         grad_transform=grad_transform,
     )
+
+    added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
+    import_ctx, call_ctx, object_ctx = {}, {}, {}
+    for bsym in trace_of_fwd.bound_symbols:
+        cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
+        import_ctx.update(cur_import_ctx)
+        call_ctx.update(cur_call_ctx)
+        object_ctx.update(cur_object_ctx)
+
+    if import_ctx:
+        added_bsym._import_ctx.update(import_ctx)
+    if call_ctx:
+        if added_bsym._call_ctx is not None:
+            added_bsym._call_ctx.update(call_ctx)
+        else:
+            added_bsym._call_ctx = call_ctx
+    if object_ctx:
+        added_bsym._object_ctx.update(object_ctx)
     return forward_result

thunder/core/proxies.py

Lines changed: 20 additions & 0 deletions

@@ -1502,6 +1502,26 @@ def distparallel_type(self):
     def thunder_fsdp_padding_size(self):
         return self._thunder_fsdp_padding_size

+    # n.b.(crcrpar): just returning the strides of a contiguous tensor, for `_make_wrapper_subclass`
+    def stride(self) -> Sequence[int]:
+        shape = self.shape
+        if len(shape) == 1:
+            return (1,)
+        elif len(shape) == 0:
+            return tuple()
+        else:
+            import numpy
+
+            _stride = reversed(numpy.cumprod([1] + list(reversed(shape[1:]))).tolist())
+            return tuple(_stride)
+
+    def storage_offset(self) -> int:
+        return -1
+
+    @property
+    def layout(self) -> torch.layout:
+        return torch.strided
+
     # We need to implement `__len__` as
     # > In addition to bypassing any instance attributes in the
     # > interest of correctness, implicit special method lookup
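As a sanity check of the contiguous-stride formula used by `stride()` above (a standalone sketch, not part of the commit; it only relies on `numpy` and `torch`): stride[i] of a C-contiguous tensor is the product of shape[i + 1:].

import numpy
import torch


def contiguous_stride(shape: tuple[int, ...]) -> tuple[int, ...]:
    # stride[i] of a C-contiguous tensor is the product of shape[i + 1:].
    if len(shape) == 0:
        return tuple()
    if len(shape) == 1:
        return (1,)
    return tuple(reversed(numpy.cumprod([1] + list(reversed(shape[1:]))).tolist()))


for shape in [(), (5,), (2, 3), (2, 3, 4)]:
    assert contiguous_stride(shape) == tuple(torch.empty(shape).stride())
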

thunder/core/pytree.py

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,4 @@
+from enum import Enum
 from functools import partial
 from types import FunctionType
 import dataclasses
@@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE):
         and not is_likely_from_collections_namedtuple(args)
         and not dataclasses.is_dataclass(args)
         and not type(args).__module__.startswith("torch.return_types")
+        and not issubclass(type(args), Enum)
     ):
         raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
     return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace)
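A small sketch of what the relaxed guard permits (not part of the commit): passing an `Enum` member through `tree_flatten` no longer trips the type check, and `optree` is assumed to treat the unregistered instance as a leaf. The `Mode` enum is hypothetical.

from enum import Enum

from thunder.core.pytree import tree_flatten


class Mode(Enum):  # hypothetical enum used only for illustration
    FAST = 1
    ACCURATE = 2


# Previously this raised: TypeError: tree_flatten of type <enum 'Mode'> is not supported.
leaves, _ = tree_flatten(Mode.FAST)
assert leaves == [Mode.FAST]
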

thunder/core/trace_interpreter.py

Lines changed: 5 additions & 2 deletions

@@ -128,8 +128,11 @@ def add_to_swap_map(old, new):
         old = old.replace(shape=new._shape)

     if isinstance(new, VJPDual):
-        swap_map[variableify(new.primal)] = old
-        new.primal = old
+        # note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`
+        # seems to cause `new.primal` == `old`, leading to a cycle in swapping.
+        if (key := variableify(new.primal)) != variableify(old):
+            swap_map[key] = old
+            new.primal = old
     else:
         assert isinstance(new, ProxyInterface), (old, new)
         swap_map[variableify(new)] = old
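To make the "cycle in swapping" concrete, here is a toy model (generic Python, not Thunder's actual swap logic): recording a self-mapping would make any resolution that follows the map to a fixed point spin forever, so such entries are simply skipped.

def add_swap(swap_map: dict[str, str], old: str, new: str) -> None:
    # Mirror of the guard above: skip self-mappings such as "t0" -> "t0",
    # which could never be resolved to a fixed point.
    if new != old:
        swap_map[new] = old


def resolve(name: str, swap_map: dict[str, str]) -> str:
    # Follow the map until the name is no longer remapped.
    while name in swap_map:
        name = swap_map[name]
    return name


swap_map: dict[str, str] = {}
add_swap(swap_map, "t0", "t0")  # ignored: would make resolve() loop forever
add_swap(swap_map, "t1", "t2")
assert resolve("t2", swap_map) == "t1"
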

thunder/core/transform_common.py

Lines changed: 1 addition & 1 deletion

@@ -165,7 +165,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
         # may mark some of the operation's outputs as unused
         some_unused = False
         for out in bsym.flat_proxy_outs:
-            if variableify(out) in needed_proxies and producer_map[out] == bsym:
+            if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym:
                 needed = True
             else:
                 some_unused = True
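The switch from `producer_map[out]` to `producer_map.get(out, None)` just tolerates outputs with no recorded producer; a minimal illustration with a plain dict (hypothetical names, not the dce internals):

producer_map = {"t0": "bsym_0"}  # hypothetical proxy-name -> producing bound symbol

# Indexing a missing key raises, which is what the old code did for outputs
# whose producer was never recorded ...
try:
    producer_map["t1"]
except KeyError:
    pass

# ... while .get() returns None, so the comparison against the bound symbol
# is simply False and the output is treated as unused.
assert producer_map.get("t1", None) != "bsym_0"
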

thunder/executors/torch_compile.py

Lines changed: 6 additions & 0 deletions

@@ -56,6 +56,12 @@ def _to_torch(*args, **kwargs) -> Any:
        if torch_op is None:
            raise RuntimeError("op not found for {bsym.sym.name}")

+        # NOTE(crcrpar): Currently `ltorch.t` is mapped to `torchex.transpose`,
+        # thus `args` needs to be updated to have dim0 and dim1.
+        if bsym.sym.id == "torch.t":
+            utils.check(len(args) == 1, lambda: f"{bsym.sym.id} takes only one argument but {args=}")
+            args = args + (0, 1)
+
        return torch_op(*args, **kwargs)

    return _to_torch
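The rewrite leans on `torch.t(x)` being the same as `torch.transpose(x, 0, 1)` for 2-D inputs, which is why appending `(0, 1)` to `args` is sufficient; a quick standalone check (not part of the commit):

import torch

x = torch.randn(3, 5)
torch.testing.assert_close(torch.t(x), torch.transpose(x, 0, 1))
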

thunder/executors/torchex.py

Lines changed: 31 additions & 0 deletions

@@ -1403,13 +1403,44 @@ def _copy_with_setitem_impl(a, key, value):
 #

 matmul = _register_torch_operation("matmul")
+_scaled_mm = _register_torch_operation("_scaled_mm")
 outer = _register_torch_operation("outer")

 _register_implementation(prims.matmul, matmul, checker=_always_executable)

 _register_implementation(ltorch.matmul, matmul, checker=_always_executable)
 _register_implementation(ltorch.outer, outer, checker=_always_executable)

+
+def _scaled_mm_transform(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypeLike | None = None,
+    use_fast_accum: bool = False,
+):
+
+    def is_column_major(mat: TensorLike) -> bool:
+        return mat.stride()[0] == 1 and mat.stride()[1] > 1
+
+    result_dtype: torch.dtype = to_torch_dtype(a.dtype if out_dtype is None else out_dtype)
+    if not is_column_major(b):
+        b = b.t().contiguous().t()
+
+    return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum)
+
+
+_register_implementation(
+    ltorch._scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform
+)
+_register_implementation(
+    ltorch.core_aten_scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform
+)
+
 #
 # Normalization operations
 #
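`torch._scaled_mm` expects its second operand in column-major layout, which is what the `b.t().contiguous().t()` trick in the transform above produces without changing any values; a standalone check of that layout trick (plain PyTorch, not part of the commit):

import torch

b = torch.randn(32, 64)        # row-major: stride (64, 1)
b_cm = b.t().contiguous().t()  # same values, column-major: stride (1, 32)

torch.testing.assert_close(b, b_cm)
assert b_cm.stride() == (1, 32)
assert b_cm.stride()[0] == 1 and b_cm.stride()[1] > 1  # the is_column_major check above
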

thunder/tests/test_tensor_subclass.py

Lines changed: 80 additions & 1 deletion

@@ -1,18 +1,29 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING

+from lightning_utilities.core.imports import package_available
 import pytest
 import torch
+import torch.nn as nn
 from torch.utils import _pytree as pytree

 import thunder
-from thunder.tests.framework import instantiate
+from thunder.dynamo.compiler import ThunderCompiler
+from thunder.tests.framework import (
+    DynamoThunderExecutor,
+    TorchExecutor,
+    instantiate,
+    nvFuserExecutor,
+)
 from thunder.tests.make_tensor import make_tensor

 if TYPE_CHECKING:
     from typing import Any


+TORCHAO_AVAILABLE = package_available("torchao")
+
+
 @torch._dynamo.allow_in_graph
 class EncapsulateXandScale(torch.autograd.Function):
     @staticmethod
@@ -232,3 +243,71 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     torch.testing.assert_close(expected, actual)
     if requires_grad:
         actual.mean().backward()
+
+
+@instantiate(
+    dtypes=(thunder.core.dtypes.float32, thunder.core.dtypes.bfloat16),
+    devicetypes=(thunder.core.devices.DeviceType.CUDA,),
+    executors=(TorchExecutor, nvFuserExecutor, DynamoThunderExecutor),
+    decorators=(
+        pytest.mark.skipif(
+            not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
+            reason="Requires capability >= 8.9 and torchao",
+        ),
+        pytest.mark.parametrize("bias", (True, False)),
+    ),
+)
+def test_torchao_float8_linear(executor, device, dtype, bias):
+    from torchao.float8 import convert_to_float8_training
+
+    batch_size, in_features, out_features = 16, 32, 64
+    device = torch.device("cuda")
+    torch_dtype = thunder.core.dtypes.to_torch_dtype(dtype)
+
+    model = nn.Sequential(
+        nn.Linear(in_features, out_features, bias=bias),
+        nn.GELU(approximate="tanh"),
+        nn.Linear(out_features, out_features, bias=bias),
+    ).to(device=device, dtype=torch_dtype)
+    fp8_model = convert_to_float8_training(model)
+    x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)
+
+    expected: torch.Tensor
+    jitted: nn.Module
+    backend: ThunderCompiler | None = None
+
+    if is_thunderfx := executor == DynamoThunderExecutor:
+        torch._dynamo.reset()
+        expected = torch.compile(fp8_model)(x)
+        backend = ThunderCompiler()
+        jitted = torch.compile(fp8_model, backend=backend)
+    else:
+        expected = fp8_model(x)
+        jitted = executor.make_callable(fp8_model)
+
+    if bias and dtype == thunder.core.dtypes.bfloat16 and executor == nvFuserExecutor:
+        with pytest.raises(
+            RuntimeError, match="Failed to compute the min-cut on the graph due to a path with infinite capacity"
+        ):
+            jitted(x)
+        return
+    actual = jitted(x)
+    if bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor:
+        with pytest.raises(AssertionError, match="Tensor-likes are not close"):
+            torch.testing.assert_close(actual, expected)
+        return
+
+    if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (
+        not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor
+    ):
+        pytest.xfail("numerical error")
+    torch.testing.assert_close(actual, expected)
+
+    # TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
+    # Currently no subgraphs go to thunder.jit.
+    if is_thunderfx:
+        for subgraph in backend.subgraph_infos:
+            if not bias and dtype == thunder.core.dtypes.bfloat16:
+                assert not subgraph.thunder_compiled_fns
+            else:
+                assert subgraph.thunder_compiled_fns

thunder/torch/__init__.py

Lines changed: 90 additions & 1 deletion

@@ -1406,7 +1406,9 @@ def t(a: TensorLike, /) -> TensorLike:
         lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D",
         RuntimeError,
     )
-    return prims.transpose(a, (1, 0)) if a.ndim == 2 else a
+    if a.ndim != 2:
+        return a
+    return transpose(a, 0, 1)


 @torchsymbol(torch.ops.aten.t.default, id="torch.ops.aten.t.default")
@@ -1480,6 +1482,17 @@ def core_aten_transpose(a: TensorProxy, dim0: int, dim1: int) -> TensorProxy:
     return _transpose_impl(a, dim0, dim1)


+def _transpose_grad(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
+    fwd = transpose(a, dim0, dim1)
+    g = get_grad(fwd)
+    a_grad = transpose(g, dim0, dim1)
+    put_grad(a, a_grad)
+    return fwd
+
+
+register_grad(transpose, _transpose_grad)
+
+
 @torchsymbol(torch.unbind, is_method=True)
 def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]:
     utils.check(
@@ -4282,6 +4295,82 @@ def outer(a: TensorLike, b: TensorLike, /) -> TensorLike:
     return a[:, None] * b[None, :]


+# TODO(crcrpar): Add nvfuser support of `matmul(a.float() * scale_a, b.float() * scale_b) + bias`
+# So far I haven't managed to get a nice result out of the nvfuser region, as noted in
+# https://github.com/Lightning-AI/lightning-thunder/pull/1415/files#r1892875183
+# reference: https://github.com/pytorch/pytorch/blob/6d4cd3e/torch/_meta_registrations.py#L5566
+def _scaled_mm_impl(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypeLike | None = None,
+    use_fast_accum: bool = False,
+) -> TensorLike:
+    fp8_dtypes = {dtypes.float8_e4m3fn, dtypes.float8_e4m3fnuz, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz}
+    # TODO(crcrpar): Devise a way to make sure `a` is row-major and `b` is column-major.
+    utils.check(
+        (
+            (a.ndim == 2 and b.ndim == 2)
+            and (a.shape[1] == b.shape[0])
+            and (a.shape[1] % 16 == 0 and b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+            and (to_dtype(a.dtype) in fp8_dtypes and to_dtype(b.dtype) in fp8_dtypes)
+            and not (a.dtype == dtypes.float8_e5m2 and b.dtype == dtypes.float8_e5m2)
+            and to_device(a.device).type == "cuda"
+        ),
+        lambda: f"data matrices of {a=} and {b=} do not satisfy the condition.",
+    )
+    args = [a, b, scale_a, scale_b]
+    if bias is not None:
+        args.append(bias)
+    utils.check_same_device(args)
+    utils.check(
+        (
+            (scale_a.numel() == 1 and scale_b.numel() == 1)
+            and (scale_a.dtype == dtypes.float32 and scale_b.dtype == dtypes.float32)
+        ),
+        lambda: f"Only tensor-wise scaling is supported but {scale_a.shape = } and {scale_b.shape = }",
+        exception_type=NotImplementedError,
+    )
+    result_dtype = a.dtype if out_dtype is None else to_dtype(out_dtype)
+    return TensorProxy(
+        like=a,
+        shape=(a.shape[0], b.shape[1]),
+        device=a.device,
+        dtype=result_dtype,
+    )
+
+
+@torchsymbol(torch._scaled_mm)
+def _scaled_mm(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypeLike | None = None,
+    use_fast_accum: bool = False,
+) -> TensorLike:
+    return _scaled_mm_impl(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)
+
+
+@torchsymbol(torch.ops.aten._scaled_mm.default, id="torch.ops.aten._scaled_mm")
+def core_aten_scaled_mm(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypeLike | None = None,
+    use_fast_accum: bool = False,
+) -> TensorLike:
+    return _scaled_mm_impl(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)
+
+
 #
 # Normalization operations
 #
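For reference, a direct call to `torch._scaled_mm` that satisfies the constraints the meta function above checks: 2-D fp8 operands with K and N divisible by 16, a row-major `a`, a column-major `b`, and tensor-wise fp32 scales. This is a sketch assuming a CUDA device with fp8 support (compute capability >= 8.9) and a recent PyTorch in which the scales are positional arguments, as the call in `torchex.py` assumes; it is not part of the commit.

import torch

device = torch.device("cuda")
a = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)  # row-major (16, 32)
b = torch.randn(32, 64, device=device).to(torch.float8_e4m3fn)
b = b.t().contiguous().t()                                      # column-major (32, 64)
scale_a = torch.tensor(1.0, device=device)                      # tensor-wise fp32 scales
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
assert out.shape == (16, 64) and out.dtype == torch.bfloat16
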
