Merge branch 'main' into benchmark_adam
riccardofelluga authored Feb 21, 2025
2 parents 1809269 + 78d84e5 commit 9be0bd4
Showing 31 changed files with 1,573 additions and 222 deletions.
5 changes: 5 additions & 0 deletions docs/source/intermediate/benchmarking.rst
@@ -272,3 +272,8 @@ As seen earlier, it's possible to write benchmarks for models and not just stand
benchmark_for_compute_type(compute_type, benchmark, fn, *args, **kwargs)

And that's as simple as that! Just add the decorator ``@parametrize_compute_type`` after your parametrization, add the ``compute_type`` argument, and use ``benchmark_for_compute_type`` to call the benchmark function.
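For example, a benchmark wired up this way could look like the following minimal sketch, where ``MyBenchmark`` is a hypothetical benchmark class used only for illustration and the import locations of the helpers are assumed to match ``thunder/benchmarks/targets.py``::

    import pytest

    from thunder.benchmarks import MyBenchmark  # hypothetical benchmark class
    from thunder.benchmarks.targets import (  # assumed import location of the helpers
        ComputeType,
        benchmark_for_compute_type,
        parametrize_compute_type,
        torch_executor,
    )


    @pytest.mark.parametrize("bs", (1, 8), ids=("bs1", "bs8"))
    @parametrize_compute_type
    def test_my_benchmark(benchmark, bs: int, compute_type: ComputeType):
        bench = MyBenchmark(batch_size=bs)  # hypothetical constructor

        args, kwargs = bench.make_batch()
        fn = torch_executor(bench.fn())

        benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)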

Isolate benchmarks to avoid OutOfMemory errors
----------------------------------------------

When running multiple benchmarks in sequence, ``pytest`` does not always clean up completely, so benchmarks that pass when run standalone can still fail in a batch run. The main problem we observed is that memory is not entirely freed before the next benchmark starts; the ``--isolate-benchmarks`` option addresses this. It separates the benchmark runs by spawning a sub-process for each benchmark configuration and running them one after the other. Logs of failed runs are saved in the ``failed_benchmarks_logs`` folder, and benchmark results are saved as JSON in the ``benchmarks_reports`` folder, unless a different directory is specified via the ``THUNDER_BENCH_DIR`` environment variable.
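For example, an isolated run that also redirects the JSON reports could look like ``THUNDER_BENCH_DIR=/tmp/bench_reports pytest thunder/benchmarks/targets.py -k "adam" --isolate-benchmarks`` (an illustrative invocation; adjust the target file, the ``-k`` filter, and the report directory to your setup).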
4 changes: 2 additions & 2 deletions notebooks/adding_custom_operator.ipynb
@@ -95,9 +95,9 @@
"source": [
"Our new operator has the following signature `sincos(x: Tensor) -> Tuple[Tensor, Tensor]`. It takes a tensor as input and returns a tuple of two tensors. The first tensor is the sine of the input and the second tensor is the cosine of the input.\n",
"\n",
"We call all callables that should be recorded in the trace *Symbols*. Symbols are the building blocks of the trace. Symbols are either primitives or composite operators. Composite perators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.\n",
"We call all callables that should be recorded in the trace *Symbols*. Symbols are the building blocks of the trace. Symbols are either primitives or composite operators. Composite operators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.\n",
"\n",
"The easiest way to register a new operator is through defining a meta - defining how the metadata of the output looks like give the metadata of the inputs and an implementation (dealing with concrete objects like Python `Number`s and PyTorch `Tensor`s) and register both of them through an executor. This will automatically create a symbol for us.\n",
"The easiest way to register a new operator is through defining a meta - defining how the metadata of the output looks like given the metadata of the inputs and an implementation (dealing with concrete objects like Python `Number`s and PyTorch `Tensor`s) and register both of them through an executor. This will automatically create a symbol for us.\n",
"\n",
"So we create an executor:"
]
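The executor-creation cell itself is cut off in this hunk; a minimal sketch of what it can look like, assuming the ``thunder.extend`` API (``OperatorExecutor``, ``register_executor``, ``OperatorExecutor.register_operator``) and ``thunder.TensorProxy(like=...)`` behave as named here, is:

import torch
import thunder
from thunder.extend import OperatorExecutor, register_executor  # assumed import path

# Create an executor that will own the new symbol, and register it with thunder.
sincos_executor = OperatorExecutor("sincos_executor", version="0.1")
register_executor(sincos_executor)


def sincos_meta(inp: thunder.TensorProxy) -> tuple[thunder.TensorProxy, thunder.TensorProxy]:
    # The meta only describes the outputs' metadata: two tensors shaped like the input.
    return thunder.TensorProxy(like=inp), thunder.TensorProxy(like=inp)


def sincos_impl(inp: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # The implementation operates on concrete tensors.
    return torch.sin(inp), torch.cos(inp)


# Registering the meta and the implementation creates the Symbol for us
# (keyword names are assumed; check thunder.extend for the exact signature).
sincos = sincos_executor.register_operator("sincos", meta=sincos_meta, fn=sincos_impl)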
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -1,7 +1,7 @@
torch >=2.3.0
looseversion ==1.3.0
lightning-utilities >=0.7.0
numpy >=1.23.0,<2 # not yet ready for numpy 2
numpy
networkx >= 3.3
optree >=0.12.1
opt_einsum >= 3.3.0
2 changes: 1 addition & 1 deletion requirements/notebooks.txt
@@ -1,5 +1,5 @@
ipython[all] ~=8.30.0
numpy >=1.23.0,<2 # not yet ready for numpy 2
numpy
liger-kernel == 0.4.0
cuda-python
litgpt == 0.5.1
4 changes: 2 additions & 2 deletions requirements/test.txt
@@ -7,10 +7,10 @@ pytest-xdist ==3.6.1
pytest-random-order ==1.1.1
pytest-timestamper ==0.0.10
graphviz ==0.20.3
fdm ==0.4.1
fdm ==0.5.0
expecttest ==0.3.0 # for test_ddp.py
hypothesis ~=6.124.7 # for test_ddp.py
numpy >=1.23.0,<2 # for test_ops.py; not yet ready for numpy 2
numpy
einops # for test_einops.py
litgpt==0.4.11 # for the model definition in tests and benchmarks
absl-py # thunder/benchmarks/test_benchmark_litgpt.py
100 changes: 100 additions & 0 deletions thunder/benchmarks/__init__.py
@@ -2877,6 +2877,106 @@ def fn(self) -> Callable:
return model


class DeepSeekSGLangMoEBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (see linked file for details)
# Adapted from
# https://github.com/sgl-project/sglang/blob/de5533341ee3c1b7667b1eb1f209b6825335d136/python/sglang/srt/layers/moe/topk.py#L23
@staticmethod
def fused_topk_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
topk_weights = torch.nn.functional.softmax(gating_output.float(), dim=-1)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids

# Adapted from
# https://github.com/sgl-project/sglang/blob/d23cb9a01ed7f7e39f40e3f5ad7d271d3ac52ce2/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py#L76
@staticmethod
def fused_moe_def(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
) -> torch.Tensor:
assert not use_fp8_w8a8, "Not supported"

topk_weights, topk_ids = DeepSeekSGLangMoEBenchmark.fused_topk_native(
hidden_states=x,
gating_output=input_gating,
topk=topk,
renormalize=True,
)
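        # einsum index convention used below: t = token, a = selected expert (top-k slot),
        # i = hidden dimension, o = (sharded) intermediate dimension.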
w13_weights = w1[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = w2[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = torch.nn.functional.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

def __init__(
self,
*,
model: str,
tp_size: int,
batch_size: int,
use_fp8: bool,
) -> None:
super().__init__()

from transformers import AutoConfig

config = AutoConfig.from_pretrained(model, trust_remote_code=True)
self.num_experts = config.n_routed_experts
self.topk = config.num_experts_per_tok
self.hidden_size = config.hidden_size
self.dtype = config.torch_dtype
intermediate_size = config.intermediate_size
self.shard_intermediate_size = 2 * intermediate_size // tp_size
self.use_fp8 = use_fp8
self.num_tokens = batch_size

def make_batch(self) -> tuple[list, dict]:
make = partial(make_tensor, device="cuda", requires_grad=False)

x = make((self.num_tokens, self.hidden_size), dtype=self.dtype)

if self.use_fp8:
init_dtype = self.dtype
w1 = make((self.num_experts, self.shard_intermediate_size, self.hidden_size), dtype=init_dtype)
w2 = make((self.num_experts, self.hidden_size, self.shard_intermediate_size // 2), dtype=init_dtype)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = make((self.num_experts,), dtype=torch.float32)
w2_scale = make((self.num_experts,), dtype=torch.float32)
a1_scale = make((1,), dtype=torch.float32)
a2_scale = make((1,), dtype=torch.float32)
else:
w1 = make((self.num_experts, self.shard_intermediate_size, self.hidden_size), dtype=self.dtype)
w2 = make((self.num_experts, self.hidden_size, self.shard_intermediate_size // 2), dtype=self.dtype)
w1_scale = w2_scale = a1_scale = a2_scale = None

input_gating = make((self.num_tokens, self.num_experts), dtype=torch.float32)
return (x, w1, w2, input_gating, self.topk, self.use_fp8, w1_scale, w2_scale, a1_scale, a2_scale), {}

def fn(self) -> Callable:
return DeepSeekSGLangMoEBenchmark.fused_moe_def


class BatchNormBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
def __init__(
self,
96 changes: 95 additions & 1 deletion thunder/benchmarks/conftest.py
@@ -1,16 +1,110 @@
import os
from os import path
import platform
import psutil
from typing import Any
import warnings
import importlib.util
from pytest import hookimpl, TestReport, Item, Parser
import multiprocessing as mp
import subprocess

BENCHMARK_JSON_DIR = "benchmarks_reports"
FAILED_BENCHMARK_LOGS_DIR = "failed_benchmarks_logs"

def pytest_addoption(parser):

def pytest_addoption(parser: Parser):
# CLI option to specify where to store the benchmark results in asv format.
# If not set or None, results won't be saved in asv.
parser.addoption("--asv_bench_dir", action="store", default=os.getenv("THUNDER_BENCH_DIR"))

parser.addoption("--isolate-benchmarks", action="store_true", default=False)


def launch_benchmark(target_file, target_name: str):
target_filename = target_name.replace("/", "_")

target_json = path.join(BENCHMARK_JSON_DIR, f"{target_filename}.json")
target_log = path.join(FAILED_BENCHMARK_LOGS_DIR, f"{target_filename}.log")

with open(target_log, "w") as target_log_file:
subprocess.run(
[
"pytest",
f"{target_file}::{target_name}",
"-vs",
"--benchmark-json",
target_json,
],
check=True,
text=True,
stderr=subprocess.STDOUT,
stdout=target_log_file,
)


def run_in_isolation(item: Item) -> TestReport:
process = mp.Process(
target=launch_benchmark,
args=(
item.location[0],
item.name,
),
)
process.start()
process.join()

# Will mark skip as passed because pytest returns error only if there are failed tests.
outcome = "failed" if process.exitcode != 0 else "passed"
target_filename = item.name.replace("/", "_")

if outcome == "passed":
test_log = path.join(FAILED_BENCHMARK_LOGS_DIR, f"{target_filename}.log")
os.remove(test_log)

benchmark_json = path.join(BENCHMARK_JSON_DIR, f"{target_filename}.json")
if outcome == "failed" or path.getsize(benchmark_json) == 0:
os.remove(benchmark_json)

return TestReport(item.nodeid, item.location, keywords=item.keywords, outcome=outcome, longrepr=None, when="call")


@hookimpl(tryfirst=True)
def pytest_runtest_protocol(item: Item, nextitem: Item):
# If the option was not passed, let pytest manage the run.
if not item.config.getoption("--isolate-benchmarks"):
return None

ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
test_report = run_in_isolation(item)

ihook.pytest_runtest_logreport(report=test_report)
ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
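    # Returning a truthy value tells pytest that the runtest protocol was handled here,
    # so the default per-item runner is not invoked again for this benchmark.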
return True


def pytest_runtestloop(session):
global BENCHMARK_JSON_DIR, FAILED_BENCHMARK_LOGS_DIR

if not session.config.getoption("--isolate-benchmarks"):
return None

mp.set_start_method("spawn")

from _pytest.terminal import TerminalReporter

terminal: TerminalReporter = session.config.pluginmanager.get_plugin("terminalreporter")

custom_report_dir = os.getenv("THUNDER_BENCH_DIR")
BENCHMARK_JSON_DIR = custom_report_dir if custom_report_dir else BENCHMARK_JSON_DIR

os.makedirs(BENCHMARK_JSON_DIR, exist_ok=True)
os.makedirs(FAILED_BENCHMARK_LOGS_DIR, exist_ok=True)

terminal.write_line(f"Saving failed benchmarks logs in {FAILED_BENCHMARK_LOGS_DIR}")
terminal.write_line(f"Saving benchmarks reports in {BENCHMARK_JSON_DIR}")


def pytest_sessionfinish(session, exitstatus):
# Save result only if the pytest session was a benchmark.
38 changes: 38 additions & 0 deletions thunder/benchmarks/targets.py
@@ -32,6 +32,7 @@
TorchbenchBenchmark,
HFBenchmark,
LinearLoRABenchmark,
DeepSeekSGLangMoEBenchmark,
thunder_apex_executor,
thunder_apex_nvfuser_executor,
thunder_cudnn_executor,
@@ -732,6 +733,43 @@ def backward_fn(result, forward_inputs, output_grads):
return backward_fn, backward_setup


# Thunder executor: RuntimeError: Advanced indexing currently only supports zero or one-dimensional integer tensors,
# but found a tensor with dtype thunder.dtypes.int64 and 2 dimensions
# https://github.com/Lightning-AI/lightning-thunder/issues/764
moe_executors = (torch_executor, torch_compile_executor, thunderfx_executor)
moe_executors_ids = (
"torch",
"torch.compile",
"thunderfx",
)


@pytest.mark.parametrize(
"bs,",
(2**i for i in range(0, 6)),
ids=(f"bs{2**i}" for i in range(0, 6)),
)
@pytest.mark.parametrize(
"executor,",
moe_executors,
ids=moe_executors_ids,
)
@pytest.mark.parametrize(
"compute_type,",
(ComputeType.INFERENCE,),
ids=("inference",),
)
def test_deepseek_sglang_moe(benchmark, bs, executor: Callable, compute_type: ComputeType):
bench: Benchmark = DeepSeekSGLangMoEBenchmark(
model="deepseek-ai/DeepSeek-R1", tp_size=8, batch_size=bs, use_fp8=False
)

args, kwargs = bench.make_batch()
fn = executor(bench.fn())

benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)


#
# interpreter benchmarks
#
1 change: 0 additions & 1 deletion thunder/core/symbol.py
@@ -345,7 +345,6 @@ def tag_tensorproxy_output_as_detached(proxy):
lambda: f"A symbol {self} was called while processing a primitive",
exception_type=AssertionError,
)
assert symbols_list is not None

symbols_list.append(bsym)
return result
13 changes: 2 additions & 11 deletions thunder/core/transforms.py
@@ -1430,20 +1430,11 @@ def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
def _log_sigmoid_grad(
a: TensorProxy,
) -> TensorProxy:
from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid
from thunder.torch import where, exp, logsigmoid

fwd = logsigmoid(a)

g = get_grad(fwd)
if a.device.type == "cpu":
# NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
# https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
buffer = exp(-abs(a))
a_grad = log_sigmoid_backward(g, a, buffer)
else:
# Here a placeholder tensor is provided.
placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype)
a_grad = log_sigmoid_backward(g, a, placeholder_buffer)
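    # d/da logsigmoid(a) = 1 - sigmoid(a) = sigmoid(-a); both where() branches below compute
    # sigmoid(-a), each written so exp() cannot overflow on its half of the domain.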
a_grad = g * where(a > 0, exp(-a) / (1 + exp(-a)), 1 - exp(a) / (1 + exp(a)))
put_grad(a, a_grad)

return fwd
