Commit 0841cf1

mx roofline: adjust mxfp8 formulas
Summary:

It's not clear whether we can write a fast fused dim0 + dim1 cast kernel, so this adjusts the roofline estimation formulas to assume separate dim0 and dim1 kernels.

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250325_b200_mxfp8_v2_triton.csv --mx_recipe_name mxfp8_cublas --shape_gen_name pow2_extended
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 66a95b3
ghstack-comment-id: 2752441017
Pull Request resolved: #1953
1 parent 36b6545 commit 0841cf1

1 file changed: +9 -20 lines changed
torchao/testing/float8/roofline_utils.py

Lines changed: 9 additions & 20 deletions
@@ -183,27 +183,16 @@ def get_tensor_memory_traffic_ovhd_s(
             "mxfp8_cutlass",
             "mxfp8_cublas",
         ), "unsupported"
-
-        if tensor_role == "weight":
-            # x_bf16 = ...
-            # kernel 1: x_bf16 -> x_mxfp8_dim0
-            # kernel 2: x_bf16 -> x_mxfp8_dim1
-            if fuse_with_prev:
-                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
-            else:
-                kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-            res_bytes = [kernel_1_rw, kernel_2_rw]
+        # For now, assume that we can't profitably fuse kernel 1 and kernel 2
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0
+        # kernel 2: x_bf16 -> x_mxfp8_dim1
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
         else:
-            # x_bf16 = ...
-            # kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1
-            if fuse_with_prev:
-                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2
-            else:
-                kernel_1_rw = (
-                    BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2
-                )
-            res_bytes = [kernel_1_rw]
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw, kernel_2_rw]
 
     # convert from bytes to seconds
     res_s = [
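
To make the effect of the change concrete, here is a minimal standalone sketch of the byte-count formulas before and after this commit. It is not the code in `roofline_utils.py`: the function names `cast_traffic_bytes_old` and `cast_traffic_bytes_new` are hypothetical, the constants assume 2 bytes per bf16 element and 1 byte per float8 element, and, like the formulas in the diff, it does not count traffic for the MX scales.

```python
# Assumed element sizes: bf16 is 2 bytes, float8 is 1 byte.
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1


def cast_traffic_bytes_old(numel, tensor_role, fuse_with_prev=False):
    """Pre-commit formulas (hypothetical helper name)."""
    # If the cast is fused with the previous op, the bf16 read is free.
    read_1 = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
    if tensor_role == "weight":
        # Two kernels: one for dim0, one for dim1.
        kernel_1_rw = read_1 + BYTES_PER_EL_FLOAT8 * numel
        kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        return [kernel_1_rw, kernel_2_rw]
    # One fused kernel: read bf16 once, write both dim0 and dim1 outputs.
    return [read_1 + BYTES_PER_EL_FLOAT8 * numel * 2]


def cast_traffic_bytes_new(numel, fuse_with_prev=False):
    """Post-commit formulas: always two separate kernels (hypothetical helper name)."""
    read_1 = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
    kernel_1_rw = read_1 + BYTES_PER_EL_FLOAT8 * numel
    kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
    return [kernel_1_rw, kernel_2_rw]


if __name__ == "__main__":
    numel = 4096 * 4096
    old = sum(cast_traffic_bytes_old(numel, "input"))
    new = sum(cast_traffic_bytes_new(numel))
    print(f"old: {old} bytes, new: {new} bytes, ratio: {new / old}")
```

For a non-weight tensor with no fusion, the estimate goes from `4 * numel` bytes (one bf16 read plus two float8 writes) to `6 * numel` bytes (two bf16 reads plus two float8 writes). For weights the estimate is unchanged, since the old code already modeled two kernels there.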
