diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/float8/roofline_utils.py index 6578f1721f..92becb9b94 100644 --- a/torchao/testing/float8/roofline_utils.py +++ b/torchao/testing/float8/roofline_utils.py @@ -183,27 +183,16 @@ def get_tensor_memory_traffic_ovhd_s( "mxfp8_cutlass", "mxfp8_cublas", ), "unsupported" - - if tensor_role == "weight": - # x_bf16 = ... - # kernel 1: x_bf16 -> x_mxfp8_dim0 - # kernel 2: x_bf16 -> x_mxfp8_dim1 - if fuse_with_prev: - kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel - else: - kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel - kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel - res_bytes = [kernel_1_rw, kernel_2_rw] + # For now, assume that we can't profitably fuse kernel 1 and kernel 2 + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_bf16 -> x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel else: - # x_bf16 = ... - # kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1 - if fuse_with_prev: - kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2 - else: - kernel_1_rw = ( - BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2 - ) - res_bytes = [kernel_1_rw] + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, kernel_2_rw] # convert from bytes to seconds res_s = [