Commit 522f5b8

[float8nocompile] add triton kernel which does fp8 conversion to col major and transpose in col major at once (#1566)
1 parent: 5e59b51

File tree: 2 files changed, +236 -2 lines changed

torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py (+160, -2)

@@ -250,8 +250,8 @@ def to_fp8_col_major_t(
         block_col_offs[:, None] * output_stride_row
         + block_row_offs[None, :] * output_stride_col
     )
-    out_mask = (block_row_offs[:, None] < output_num_rows) & (
-        block_col_offs[None, :] < output_num_cols
+    out_mask = (block_col_offs[:, None] < output_num_rows) & (
+        block_row_offs[None, :] < output_num_cols
     )
     tl.store(out_ptr + out_offs, fp8_vals, mask=out_mask)

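A note on the two-line fix above: to_fp8_col_major_t writes the transposed output, whose row count is output_num_rows = input_num_cols, so its rows are indexed by block_col_offs and its columns by block_row_offs. The old mask paired each offset with the wrong bound. Below is a minimal plain-PyTorch sketch of the corrected bounds check with illustrative sizes; it mirrors the kernel's names but is not the kernel itself.

import torch

# Illustrative sizes: the input is 5 x 7, so the transposed output is 7 x 5.
output_num_rows, output_num_cols = 7, 5

# A 4 x 4 tile that hangs off the edge of the tensor.
block_row_offs = 4 + torch.arange(4)  # input-row offsets 4..7
block_col_offs = 4 + torch.arange(4)  # input-col offsets 4..7

# In the transposed output, rows are indexed by the input's column offsets
# and columns by the input's row offsets, so each offset must be checked
# against the matching transposed bound.
out_mask = (block_col_offs[:, None] < output_num_rows) & (
    block_row_offs[None, :] < output_num_cols
)
print(out_mask)  # True only where both transposed indices are in range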
@@ -381,6 +381,77 @@ def _to_fp8_row_major_t_and_non_t(
     tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)


+@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+@triton.jit
+def _to_fp8_col_major_t_and_non_t(
+    input_ptr,
+    col_major_out_ptr,
+    col_major_t_out_ptr,
+    scale_ptr,
+    num_elements: int,
+    fp8_dtype_min: float,
+    fp8_dtype_max: float,
+    input_num_rows: int,
+    input_num_cols: int,
+    input_stride_row: int,
+    input_stride_col: int,
+    col_major_out_stride_row: int,
+    col_major_out_stride_col: int,
+    col_major_t_out_stride_row: int,
+    col_major_t_out_stride_col: int,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE_ROWS: tl.constexpr,
+    BLOCK_SIZE_COLS: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    """
+    Reads a row-major, high precision input tensor and writes 2 output tensors:
+    1) fp8 col major tensor
+    2) fp8 col major tensor (transposed)
+    """
+    # compute block/tile indices
+    block_row_id = tl.program_id(axis=0)
+    block_col_id = tl.program_id(axis=1)
+
+    # load scaling factor
+    scale = tl.load(scale_ptr).to(tl.float32)
+
+    # load block of input tensor
+    block_row_start = block_row_id * BLOCK_SIZE_ROWS
+    block_col_start = block_col_id * BLOCK_SIZE_COLS
+    block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS)
+    block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS)
+    input_offs = (
+        block_row_offs[:, None] * input_stride_row
+        + block_col_offs[None, :] * input_stride_col
+    )
+    mask = (block_row_offs[:, None] < input_num_rows) & (
+        block_col_offs[None, :] < input_num_cols
+    )
+    vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype)
+
+    # perform conversion
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+
+    # 1. write col-major output
+    out_offs = block_row_offs[:, None] + block_col_offs[None, :] * input_num_rows
+    tl.store(col_major_out_ptr + out_offs, fp8_vals, mask=mask)
+
+    # 2. write transposed col-major output
+    col_major_t_num_rows = input_num_cols
+    col_major_t_num_cols = input_num_rows
+    out_offs = (
+        block_col_offs[:, None] * col_major_t_out_stride_row
+        + block_row_offs[None, :] * col_major_t_out_stride_col
+    )
+    out_mask = (block_col_offs[:, None] < col_major_t_num_rows) & (
+        block_row_offs[None, :] < col_major_t_num_cols
+    )
+    tl.store(col_major_t_out_ptr + out_offs, fp8_vals.trans(1, 0), mask=out_mask)
+
+
 @triton.autotune(configs=kernel_configs_1D, key=["num_elements"])
 @triton.jit
 def _amax_atomic(
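As a mental model for the new kernel, here is a hedged eager-mode sketch of the same computation in plain PyTorch (the reference function and its names are illustrative, not part of the library): each element is scaled, clamped to the fp8 representable range, and cast, and the result is materialized once in col-major layout and once as its transpose, also col-major.

import torch

def fp8_col_major_t_and_non_t_reference(hp_tensor, scale, fp8_dtype=torch.float8_e4m3fn):
    # Elementwise conversion: scale, clamp to the fp8 range, cast.
    info = torch.finfo(fp8_dtype)
    fp8_vals = (hp_tensor.float() * scale).clamp(info.min, info.max).to(fp8_dtype)
    # Output 1: same logical shape, column-major strides (1, num_rows).
    col_major = fp8_vals.t().contiguous().t()
    # Output 2: the logical transpose, also stored column-major.
    col_major_t = fp8_vals.t()
    return col_major, col_major_t

# Example with a dummy tensorwise scale of 1.0.
x = torch.randn(32, 16, dtype=torch.bfloat16)
cm, cm_t = fp8_col_major_t_and_non_t_reference(x, torch.tensor(1.0))
assert cm.stride() == (1, 32) and cm_t.stride() == (1, 16)

The Triton kernel produces both layouts from a single read of each input tile, which is the point of fusing the two writes into one pass.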
@@ -859,6 +930,93 @@ def hp_to_fp8_row_major_t_and_non_t(
     return fp8_tensor_row_major, fp8_tensor_row_major_t


+def hp_to_fp8_col_major_t_and_non_t(
+    hp_tensor: torch.Tensor,
+    fp8_dtype: torch.dtype,
+    linear_mm_config: LinearMMConfig,
+    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+) -> tuple[Float8Tensor, Float8Tensor]:
+    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"
+
+    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
+    tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
+
+    fp8_dtype_min = torch.finfo(fp8_dtype).min
+    fp8_dtype_max = torch.finfo(fp8_dtype).max
+
+    # compute scaling factor for tensor
+    scale = _hp_tensor_to_scale(
+        hp_tensor,
+        tl_input_dtype,
+        fp8_dtype_max,
+        algo,
+    )
+
+    # perform fp8 conversion
+    input_num_rows, input_num_cols = hp_tensor.shape
+    num_elements = hp_tensor.numel()
+
+    # preallocate necessary output tensors
+    fp8_output_col_major = torch.empty(
+        (input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device
+    )
+    fp8_output_col_major_t = torch.empty_like(
+        hp_tensor.t(),
+        dtype=fp8_dtype,
+        device=hp_tensor.device,
+    )
+
+    # launch triton kernel to perform conversion
+    grid = lambda meta: (
+        triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]),
+        triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]),
+    )
+    _to_fp8_col_major_t_and_non_t[grid](
+        hp_tensor,
+        fp8_output_col_major,
+        fp8_output_col_major_t,
+        scale,
+        num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        input_num_rows,
+        input_num_cols,
+        hp_tensor.stride(0),
+        hp_tensor.stride(1),
+        fp8_output_col_major.stride(0),
+        fp8_output_col_major.stride(1),
+        fp8_output_col_major_t.stride(0),
+        fp8_output_col_major_t.stride(1),
+        input_dtype=tl_input_dtype,
+        output_dtype=tl_output_dtype,
+        EPS=EPS,
+    )
+
+    # for col major we need to update the strides to reflect the new memory layout
+    col_major_strides = (1, input_num_rows)
+    fp8_output_col_major = fp8_output_col_major.as_strided(
+        fp8_output_col_major.size(), col_major_strides
+    )
+
+    # wrap outputs in Float8Tensors
+    fp8_tensor_col_major = Float8Tensor(
+        fp8_output_col_major,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    fp8_tensor_col_major_t = Float8Tensor(
+        fp8_output_col_major_t,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    return fp8_tensor_col_major, fp8_tensor_col_major_t
+
+
 def _hp_tensor_to_scale(
     hp_tensor: torch.Tensor,
     tl_input_dtype: tl.core.dtype,
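For reference, a minimal call into the new entry point might look like the sketch below; it assumes a CUDA device and mirrors the arguments used in the test further down. The import paths are assumptions (KernelAlgorithm and the new function live in the file modified above; LinearMMConfig's module may differ across torchao versions).

import torch
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
    hp_to_fp8_col_major_t_and_non_t,
)
from torchao.float8.float8_tensor import LinearMMConfig  # assumed location

x = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")  # must be contiguous

fp8_col_major, fp8_col_major_t = hp_to_fp8_col_major_t_and_non_t(
    x,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    algo=KernelAlgorithm.ATOMIC_MAX,  # or KernelAlgorithm.REDUCTION
)

# Both outputs share one tensorwise scale and come from a single conversion
# pass: fp8_col_major keeps x's shape with column-major strides, and
# fp8_col_major_t is its transpose, also stored column-major.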

torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py (+76)

@@ -8,6 +8,7 @@
     KernelAlgorithm,
     hp_to_fp8_col_major,
     hp_to_fp8_col_major_t,
+    hp_to_fp8_col_major_t_and_non_t,
     hp_to_fp8_row_and_col_major,
     hp_to_fp8_row_major,
     hp_to_fp8_row_major_t,
@@ -410,3 +411,78 @@ def test_fp8_hp_to_fp8_row_major_t_and_non_t(
         torch.float8_e4m3fn,
         LinearMMConfig(),
     )
+
+
+@pytest.mark.parametrize(
+    "algo",
+    [KernelAlgorithm.REDUCTION, KernelAlgorithm.ATOMIC_MAX],
+)
+@pytest.mark.parametrize(
+    "input_shape",
+    [(2, 4), (32, 16), (512, 512)],
+)
+def test_fp8_hp_to_fp8_col_major_t_and_non_t(
+    input_shape: tuple[int, int], algo: KernelAlgorithm
+):
+    assert torch.cuda.is_available()
+    device = "cuda"
+    input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
+    x_bf16 = input_bf16.clone().detach().to(device)
+    y_bf16 = input_bf16.clone().detach().to(device)
+
+    # production implementation
+    x_fp8_row_major = hp_tensor_to_float8_dynamic(
+        x_bf16,
+        torch.float8_e4m3fn,
+        LinearMMConfig(),
+    )
+    x_fp8_col_major = x_fp8_row_major.t().contiguous().t()
+    x_fp8_col_major_t = x_fp8_row_major.t()
+
+    # float8nocompile triton implementation
+    y_fp8_col_major, y_fp8_col_major_t = hp_to_fp8_col_major_t_and_non_t(
+        y_bf16,
+        torch.float8_e4m3fn,
+        LinearMMConfig(),
+        algo=algo,
+    )
+
+    # check scales
+    assert torch.eq(x_fp8_col_major._scale, y_fp8_col_major._scale)
+    assert torch.eq(x_fp8_col_major_t._scale, y_fp8_col_major_t._scale)
+
+    # check data
+    assert torch.all(torch.eq(x_fp8_col_major._data, y_fp8_col_major._data))
+    assert torch.all(torch.eq(x_fp8_col_major_t._data, y_fp8_col_major_t._data))
+
+    # check shapes
+    assert x_fp8_col_major.shape == y_fp8_col_major.shape
+    assert x_fp8_col_major_t.shape == y_fp8_col_major_t.shape
+
+    # check strides
+    assert x_fp8_col_major.stride() == y_fp8_col_major.stride()
+    assert x_fp8_col_major_t.stride() == y_fp8_col_major_t.stride()
+
+    # check memory layout
+    assert not is_row_major(x_fp8_col_major.stride())
+    assert not is_row_major(y_fp8_col_major.stride())
+    assert not is_row_major(x_fp8_col_major_t.stride())
+    assert not is_row_major(y_fp8_col_major_t.stride())
+
+    # check underlying memory layout
+    assert (
+        x_fp8_col_major._data.storage().tolist()
+        == y_fp8_col_major._data.storage().tolist()
+    )
+    assert (
+        x_fp8_col_major_t._data.storage().tolist()
+        == y_fp8_col_major_t._data.storage().tolist()
+    )
+
+    # assert that error is raised when input tensor is not contiguous
+    with pytest.raises(AssertionError, match="tensor must be contiguous"):
+        hp_to_fp8_col_major_t_and_non_t(
+            y_bf16.t(),  # transpose so tensor memory layout is no longer contiguous
+            torch.float8_e4m3fn,
+            LinearMMConfig(),
+        )
