Add torchcompile_xentropy executor (#1655)
riccardofelluga authored Feb 9, 2025
1 parent 71c3b07 commit 5b847bc
Showing 5 changed files with 81 additions and 3 deletions.
4 changes: 3 additions & 1 deletion thunder/__init__.py
@@ -191,16 +191,18 @@
 cudnn_executor: None | extend.Executor = extend.get_executor("cudnn")
 sdpa_executor: None | extend.Executor = extend.get_executor("sdpa")
 torchcompile_cat_executor: None | extend.Executor = extend.get_executor("torchcompile_cat")
+torchcompile_xentropy_executor: None | extend.Executor = extend.get_executor("torchcompile_xentropy")
 apex_executor: None | extend.Executor = extend.get_executor("apex")
 nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
 pytorch_executor: None | extend.Executor = extend.get_executor("torch")
 
-# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> nvfuser -> torch -> python]
+# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python]
 # Note that add_default_executor inserts executor at start of list, hence the reverse order below.
 if nvfuser_executor:
     add_default_executor(nvfuser_executor)
 
 if torchcompile_cat_executor and pytorch._dynamo.is_inductor_supported():
+    add_default_executor(torchcompile_xentropy_executor)
     add_default_executor(torchcompile_cat_executor)
 
 if sdpa_executor:
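
For context (not part of the diff): after this change, torchcompile_xentropy sits between torchcompile_cat and nvfuser in the default executor priority. A minimal sketch of how one might inspect that ordering, assuming get_default_executors() is exposed at the thunder package level (it may instead need to be imported from thunder.extend):

import thunder

# Print the default executor priority. With inductor available, the expected order is
# cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python;
# executors whose backing library is unavailable simply do not appear.
print([ex.name for ex in thunder.get_default_executors()])
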
35 changes: 35 additions & 0 deletions thunder/executors/torch_compile.py
@@ -243,6 +243,41 @@ def cuda_device_checker(*args, **kwargs):
     op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops
 }
 
+# Similar to torchcompile_cat, this executor is meant to be used together with the nvfuser
+# executor so that inductor can claim the cross_entropy computation.
+required_ops = {
+    "nll_loss_backward",
+    "log_softmax_backward",
+    "torch.log_softmax",
+    "torch.nn.functional.nll_loss",
+    "torch.nn.functional.cross_entropy",
+}
+torch_compile_xentropy_ex = TorchCompileExecutor(name="torchcompile_xentropy", required_ops=required_ops)
+register_executor(torch_compile_xentropy_ex)
+
+supported_ops = {
+    prims.broadcast_in_dim.id,
+    prims.convert_element_type.id,
+    prims.div.id,
+    prims.ne.id,
+    prims.neg.id,
+    prims.pad.id,
+    prims.reshape.id,
+    prims.slice_prim.id,
+    prims.where.id,
+    "nll_loss_backward",
+    "log_softmax_backward",
+    "torch.log_softmax",
+    "torch.nn.functional.nll_loss",
+    "torch.sum",
+    "torch.take_along_dim",
+    "torch.Tensor.contiguous",
+}
+
+torch_compile_xentropy_ex._implmap = {
+    op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops
+}
+
 
 torch_compile_ex = TorchCompileExecutor(name="torchcompile")
 register_executor(torch_compile_ex)
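
A minimal usage sketch (mirroring the test added in thunder/tests/test_torch_compile_executor.py below): the new executor is passed explicitly to thunder.jit so that inductor claims the cross-entropy region. The loss_fn here is a hypothetical example, and a CUDA device is assumed because the claimed ops are gated by cuda_device_checker:

import torch
import thunder
from thunder.executors.torch_compile import torch_compile_xentropy_ex

def loss_fn(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels)

jfn = thunder.jit(loss_fn, executors=[torch_compile_xentropy_ex])

logits = torch.randn(8, 16, device="cuda", requires_grad=True)
labels = torch.randint(0, 16, (8,), device="cuda")
loss = jfn(logits, labels)

# The cross-entropy computation should appear as a single TorchCompile region in the trace.
print(thunder.last_traces(jfn)[-1].python())
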
20 changes: 19 additions & 1 deletion thunder/tests/framework.py
@@ -208,6 +208,23 @@ def version(self):
         return torch.__version__
 
 
+class TorchCompileXentropyTestExecutor(TestExecutor):
+    name = "torchcompile_xentropy"
+    supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
+    supported_dtypes = (datatypes.dtype,)
+
+    def is_available(self) -> bool:
+        return not IS_WINDOWS
+
+    def executors_list(self) -> list[extend.Executor]:
+        from thunder.executors.torch_compile import torch_compile_xentropy_ex
+
+        return [torch_compile_xentropy_ex]
+
+    def version(self):
+        return torch.__version__
+
+
 class TorchCompileCatTestExecutor(TestExecutor):
     name = "torchcompile_cat"
     supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
@@ -261,6 +278,7 @@ def make_callable(self, fn, **kwargs):
 # TODO Refactor these executors into the actual executor (sub)modules
 TorchExecutor: TorchTestExecutor = TorchTestExecutor()
 TorchCompileCatExecutor: TorchCompileCatTestExecutor = TorchCompileCatTestExecutor()
+TorchCompileXentropyExecutor: TorchCompileXentropyTestExecutor = TorchCompileXentropyTestExecutor()
 TorchCompileExecutor: TorchCompileTestExecutor = TorchCompileTestExecutor()
 DynamoThunderExecutor: DynamoThunderTestExecutor = DynamoThunderTestExecutor()
 nvFuserExecutor: None | nvFuserTestExecutor = None
@@ -368,7 +386,7 @@ def __init__(
         self.supported_executors = (
             set(supported_executors)
             if supported_executors is not None
-            else set(_all_test_executors() + [TorchCompileCatExecutor])
+            else set(_all_test_executors() + [TorchCompileCatExecutor, TorchCompileXentropyExecutor])
         )
         for ex in self.supported_executors:
             assert isinstance(ex, TestExecutor)
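
For test authors, a sketch (not part of the diff) of how the new test executor could be requested through the framework's instantiate decorator. The decorator arguments and the (executor, device, dtype) test signature follow the pattern used elsewhere in thunder's test suite and are assumptions here, not taken from this commit; executor.make_callable is the method referenced in the hunk header above:

import torch
import thunder
from thunder.tests.framework import instantiate, TorchCompileXentropyExecutor

@instantiate(executors=(TorchCompileXentropyExecutor,), dtypes=(thunder.float32,))
def test_xentropy_executor_smoke(executor, device, dtype):
    # device is expected to be a device string such as "cuda:0"
    def fn(logits, labels):
        return torch.nn.functional.cross_entropy(logits, labels)

    jfn = executor.make_callable(fn)
    logits = torch.randn(4, 8, device=device, requires_grad=True)
    labels = torch.randint(0, 8, (4,), device=device)
    jfn(logits, labels)
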
1 change: 1 addition & 0 deletions thunder/tests/test_extend.py
@@ -132,6 +132,7 @@ def test_get_all_executors_includes_all_native_executors():
         "sdpa",
         "torchcompile",
         "torchcompile_cat",
+        "torchcompile_xentropy",
         "python",
         "transformer_engine",
     }
24 changes: 23 additions & 1 deletion thunder/tests/test_torch_compile_executor.py
@@ -4,7 +4,12 @@
 from torch._dynamo import is_inductor_supported
 
 import thunder
-from thunder.executors.torch_compile import supported_ops, torch_compile_ex, torch_compile_cat_ex
+from thunder.executors.torch_compile import (
+    supported_ops,
+    torch_compile_ex,
+    torch_compile_cat_ex,
+    torch_compile_xentropy_ex,
+)
 from thunder.executors.torchex import ex as pytorch_ex
 from thunder.executors.nvfuserex import nvfuserex
 from thunder.tests.bf16 import device_supports_bf16
@@ -122,3 +127,20 @@ def forward_and_loss(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
     out_jitted = forward_and_loss_jitted(model, input_ids)
 
     assert_close(out, out_jitted)
+
+
+@requiresCUDA
+def test_torch_compile_xentropy_loss():
+    from transformers.loss.loss_utils import ForCausalLMLoss
+
+    logits = torch.randn(1, 2, 6, device="cuda", requires_grad=True)
+    labels = torch.randint(0, 6, (1, 2), device="cuda")
+    vocab_size = 6
+
+    closs_fn = thunder.jit(ForCausalLMLoss, executors=[torch_compile_xentropy_ex])
+    _ = closs_fn(logits, labels, vocab_size, ignore_index=-1)
+    forward_trace = thunder.last_traces(closs_fn)[-1].python()
+
+    # make sure the loss is claimed as a single torch.compile region
+    assert "TorchCompile0" in forward_trace
+    assert "TorchCompile1" not in forward_trace