
Commit 070345d

add fused transpose and non-transpose kernel and use it for grad output (#1497)
1 parent 4996101 commit 070345d

File tree (6 files changed: +275 -19 lines)

- torchao/prototype/float8nocompile/__init__.py
- torchao/prototype/float8nocompile/float8nocompile_linear.py
- torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py
- torchao/prototype/float8nocompile/kernels/__init__.py
- torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
- torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
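At a high level, the backward pass previously launched two separate casts of grad_output, one producing a row-major fp8 tensor and one producing a transposed row-major fp8 tensor, each reading the high-precision tensor again. The fused kernel added here reads the input once and writes both fp8 layouts. A rough sketch of the call-site change in backward (argument lists abbreviated; see the full diffs below):

# Before (two independent casts of the same high-precision grad_output):
#     grad_output_fp8_row_major = ToFP8RowMajor.apply(grad_output, <dtype/config/role/algo args>)
#     grad_output_t_row_major   = ToFP8RowMajorT.apply(grad_output, <dtype/config/role/algo args>)
#
# After (one fused cast; a single Triton kernel launch produces both layouts):
#     grad_output_fp8_row_major, grad_output_t_row_major = ToFP8RowMajorTAndNonT.apply(
#         grad_output, <dtype/config/role/algo args>
#     )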

Diff for: torchao/prototype/float8nocompile/__init__.py

Whitespace-only changes.

Diff for: torchao/prototype/float8nocompile/float8nocompile_linear.py (+10 -15)

@@ -16,8 +16,7 @@
     ToFP8ColumnMajor,
     ToFP8ColumnMajorT,
     ToFP8RowAndColumnMajor,
-    ToFP8RowMajor,
-    ToFP8RowMajorT,
+    ToFP8RowMajorTAndNonT,
 )
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
     KernelAlgorithm,
@@ -138,12 +137,14 @@ def backward(ctx, grad_output):
         input_fp8_col_major, weight_hp = ctx.saved_tensors

         # cast grad output to float8_e5m2 for backward
-        grad_output_fp8_row_major = ToFP8RowMajor.apply(
-            grad_output,
-            ctx.config.cast_config_grad_output.target_dtype,
-            ctx.linear_mm_config,
-            GemmInputRole.GRAD_OUTPUT,
-            ctx.kernel_algo,
+        grad_output_fp8_row_major, grad_output_t_row_major = (
+            ToFP8RowMajorTAndNonT.apply(
+                grad_output,
+                ctx.config.cast_config_grad_output.target_dtype,
+                ctx.linear_mm_config,
+                GemmInputRole.GRAD_OUTPUT,
+                ctx.kernel_algo,
+            )
         )

         # grad_input = grad_output @ weight
@@ -159,12 +160,6 @@ def backward(ctx, grad_output):
         # grad_weight = grad_output_t @ input
         # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
         # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
-        grad_output_t_row_major = ToFP8RowMajorT.apply(
-            grad_output,
-            ctx.config.cast_config_grad_output.target_dtype,
-            ctx.linear_mm_config,
-            GemmInputRole.GRAD_OUTPUT,
-            ctx.kernel_algo,
-        )
         grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
+
         return grad_input, grad_weight, None, None, None
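For context on why both layouts are needed, here is a minimal shape check of the two backward matmuls in plain bfloat16 (hypothetical sizes; the actual code runs these as scaled fp8 matmuls through Float8Tensor):

import torch

# Hypothetical sizes: M tokens, K in_features, N out_features.
M, K, N = 8, 16, 32
grad_output = torch.randn(M, N, dtype=torch.bfloat16)
weight = torch.randn(N, K, dtype=torch.bfloat16)   # linear weight (out_features, in_features)
inp = torch.randn(M, K, dtype=torch.bfloat16)      # saved forward input

# grad_input = grad_output @ weight: (M, N) @ (N, K) -> (M, K)
grad_input = torch.mm(grad_output, weight)
assert grad_input.shape == (M, K)

# grad_weight = grad_output_t @ input: (N, M) @ (M, K) -> (N, K)
# The fused cast materializes this transposed, row-major copy of grad_output up front.
grad_weight = torch.mm(grad_output.t().contiguous(), inp)
assert grad_weight.shape == (N, K)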

Diff for: torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py (+32 -4)

@@ -10,17 +10,15 @@

 import torch

-from torchao.float8.float8_tensor import (
-    GemmInputRole,
-    LinearMMConfig,
-)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
     KernelAlgorithm,
     hp_to_fp8_col_major,
     hp_to_fp8_col_major_t,
     hp_to_fp8_row_and_col_major,
     hp_to_fp8_row_major,
     hp_to_fp8_row_major_t,
+    hp_to_fp8_row_major_t_and_non_t,
 )


@@ -172,3 +170,33 @@ def forward(
     @staticmethod
     def backward(ctx, g):
         return g, None, None, None, None
+
+
+class ToFP8RowMajorTAndNonT(torch.autograd.Function):
+    """
+    A differentiable conversion to fp8.
+    * forward: convert from high precision to float8, producing both row-major (transposed) and row-major (non-transposed) outputs
+    * backward: pass the gradient through without changes
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
+        fp8_row_major, fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
+        )
+        return fp8_row_major, fp8_row_major_t
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
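A minimal usage sketch of the new autograd function (illustrative shapes; assumes a CUDA device, a PyTorch build with float8 dtypes, and the torchao modules shown in this diff):

import torch
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
    ToFP8RowMajorTAndNonT,
)

# Hypothetical grad_output-like tensor; the Triton kernels require a CUDA device.
g = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda")

fp8_row_major, fp8_row_major_t = ToFP8RowMajorTAndNonT.apply(
    g,
    torch.float8_e5m2,   # grad outputs are cast to e5m2 in the backward pass above
    LinearMMConfig(),
    GemmInputRole.GRAD_OUTPUT,
)

# Both outputs share one tensorwise scale; the second is the transposed layout.
assert fp8_row_major.shape == (32, 16)
assert fp8_row_major_t.shape == (16, 32)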

Diff for: torchao/prototype/float8nocompile/kernels/__init__.py

Whitespace-only changes.

Diff for: torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py (+158 -0)

@@ -305,6 +305,82 @@ def _to_fp8_row_and_col_major(
     tl.store(col_major_out_ptr + col_major_offs, fp8_vals, mask=mask)


+@triton.autotune(
+    configs=kernel_configs_2D,
+    key=["num_elements"],
+)
+@triton.jit
+def _to_fp8_row_major_t_and_non_t(
+    input_ptr,
+    row_major_out_ptr,
+    row_major_t_out_ptr,
+    scale_ptr,
+    num_elements: int,
+    fp8_dtype_min: float,
+    fp8_dtype_max: float,
+    input_num_rows: int,
+    input_num_cols: int,
+    input_stride_row: int,
+    input_stride_col: int,
+    row_major_out_stride_row: int,
+    row_major_out_stride_col: int,
+    row_major_t_out_stride_row: int,
+    row_major_t_out_stride_col: int,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE_ROWS: tl.constexpr,
+    BLOCK_SIZE_COLS: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    """
+    Reads a row-major, high precision input tensor and writes 2 output tensors:
+    1) fp8 row major tensor (transposed)
+    2) fp8 row major tensor
+    """
+    block_row_id = tl.program_id(axis=0)
+    block_col_id = tl.program_id(axis=1)
+
+    # load scaling factor
+    scale = tl.load(scale_ptr).to(tl.float32)
+
+    # load block of input tensor
+    block_row_start = block_row_id * BLOCK_SIZE_ROWS
+    block_col_start = block_col_id * BLOCK_SIZE_COLS
+    block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS)
+    block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS)
+    input_offs = (
+        block_row_offs[:, None] * input_stride_row
+        + block_col_offs[None, :] * input_stride_col
+    )
+    mask = (block_row_offs[:, None] < input_num_rows) & (
+        block_col_offs[None, :] < input_num_cols
+    )
+    vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype)
+
+    # perform conversion
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+
+    # write row-major output
+    row_major_offs = (
+        block_row_offs[:, None] * row_major_out_stride_row
+        + block_col_offs[None, :] * row_major_out_stride_col
+    )
+    tl.store(row_major_out_ptr + row_major_offs, fp8_vals, mask=mask)
+
+    # write transposed row-major output
+    row_major_t_num_rows = input_num_cols
+    row_major_t_num_cols = input_num_rows
+    row_major_t_offs = (
+        block_col_offs[:, None] * row_major_t_out_stride_row
+        + block_row_offs[None, :] * row_major_t_out_stride_col
+    )
+    mask = (block_row_offs[:, None] < row_major_t_num_rows) & (
+        block_col_offs[None, :] < row_major_t_num_cols
+    )
+    tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)
+
+
 @triton.autotune(configs=kernel_configs_1D, key=["num_elements"])
 @triton.jit
 def _amax_atomic(
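As a reference for what the new kernel (_to_fp8_row_major_t_and_non_t above) computes, the following eager-mode sketch mirrors its math: apply the tensorwise scale, clamp to the fp8 range, cast, then produce both layouts from a single read of the input. It is illustrative only, assumes a PyTorch build with float8 dtypes, and omits the EPS handling used in the real scale computation:

import torch

def fused_cast_reference(hp: torch.Tensor, scale: torch.Tensor,
                         fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    # Same per-element math as the kernel body: scale, clamp to the fp8 range, cast.
    info = torch.finfo(fp8_dtype)
    scaled = (hp.float() * scale).clamp(min=info.min, max=info.max)
    row_major = scaled.to(fp8_dtype)                     # non-transposed, row-major layout
    row_major_t = scaled.t().contiguous().to(fp8_dtype)  # transposed, materialized row-major
    return row_major, row_major_t

# Usage sketch with a hypothetical tensorwise scale (fp8_max / amax):
x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max().float()
rm, rm_t = fused_cast_reference(x, scale)
assert rm.shape == (4, 8) and rm_t.shape == (8, 4)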
@@ -701,6 +777,88 @@ def hp_to_fp8_row_and_col_major(
     return fp8_tensor_row_major, fp8_tensor_col_major


+def hp_to_fp8_row_major_t_and_non_t(
+    hp_tensor: torch.Tensor,
+    fp8_dtype: torch.dtype,
+    linear_mm_config: LinearMMConfig,
+    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+) -> Float8Tensor:
+    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"
+
+    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
+    tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
+
+    fp8_dtype_min = torch.finfo(fp8_dtype).min
+    fp8_dtype_max = torch.finfo(fp8_dtype).max
+
+    # compute scaling factor for tensor
+    scale = _hp_tensor_to_scale(
+        hp_tensor,
+        tl_input_dtype,
+        fp8_dtype_max,
+        algo,
+    )
+
+    # perform fp8 conversion
+    input_num_rows, input_num_cols = hp_tensor.shape
+    transposed_num_rows, transposed_num_cols = input_num_cols, input_num_rows
+    num_elements = hp_tensor.numel()
+
+    # preallocate necessary output tensors
+    fp8_output_row_major = torch.empty(
+        (input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device
+    )
+    fp8_output_row_major_t = torch.empty(
+        (transposed_num_rows, transposed_num_cols),
+        dtype=fp8_dtype,
+        device=hp_tensor.device,
+    )
+
+    # launch triton kernel to perform conversion
+    grid = lambda meta: (
+        triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]),
+        triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]),
+    )
+    _to_fp8_row_major_t_and_non_t[grid](
+        hp_tensor,
+        fp8_output_row_major,
+        fp8_output_row_major_t,
+        scale,
+        num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        input_num_rows,
+        input_num_cols,
+        hp_tensor.stride(0),
+        hp_tensor.stride(1),
+        fp8_output_row_major.stride(0),
+        fp8_output_row_major.stride(1),
+        fp8_output_row_major_t.stride(0),
+        fp8_output_row_major_t.stride(1),
+        input_dtype=tl_input_dtype,
+        output_dtype=tl_output_dtype,
+        EPS=EPS,
+    )
+
+    # wrap outputs in Float8Tensors
+    fp8_tensor_row_major = Float8Tensor(
+        fp8_output_row_major,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    fp8_tensor_row_major_t = Float8Tensor(
+        fp8_output_row_major_t,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    return fp8_tensor_row_major, fp8_tensor_row_major_t
+
+
 def _hp_tensor_to_scale(
     hp_tensor: torch.Tensor,
     tl_input_dtype: tl.core.dtype,
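The launch grid tiles the input in 2D, one program per BLOCK_SIZE_ROWS x BLOCK_SIZE_COLS tile. A small worked example with hypothetical block sizes (the real values are selected by triton.autotune from kernel_configs_2D):

# Hypothetical block sizes; the real ones come from autotuning over kernel_configs_2D.
BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS = 128, 64

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching triton.cdiv as used when building the launch grid.
    return -(-a // b)

input_num_rows, input_num_cols = 512, 512
grid = (cdiv(input_num_rows, BLOCK_SIZE_ROWS), cdiv(input_num_cols, BLOCK_SIZE_COLS))
assert grid == (4, 8)  # 32 programs, each reading one 128x64 input tile and writing both outputs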

Diff for: torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py (+75 -0)

@@ -11,6 +11,7 @@
     hp_to_fp8_row_and_col_major,
     hp_to_fp8_row_major,
     hp_to_fp8_row_major_t,
+    hp_to_fp8_row_major_t_and_non_t,
 )


@@ -335,3 +336,77 @@ def test_fp8_hp_to_fp8_row_and_col_major(
         torch.float8_e4m3fn,
         LinearMMConfig(),
     )
+
+
+@pytest.mark.parametrize(
+    "algo",
+    [KernelAlgorithm.REDUCTION, KernelAlgorithm.ATOMIC_MAX],
+)
+@pytest.mark.parametrize(
+    "input_shape",
+    [(2, 4), (32, 16), (512, 512)],
+)
+def test_fp8_hp_to_fp8_row_major_t_and_non_t(
+    input_shape: tuple[int, int], algo: KernelAlgorithm
+):
+    assert torch.cuda.is_available()
+    device = "cuda"
+    input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
+    x_bf16 = input_bf16.clone().detach().to(device)
+    y_bf16 = input_bf16.clone().detach().to(device)
+
+    # production implementation
+    x_fp8_row_major = hp_tensor_to_float8_dynamic(
+        x_bf16,
+        torch.float8_e4m3fn,
+        LinearMMConfig(),
+    )
+    x_fp8_row_major_t = x_fp8_row_major.t().contiguous()
+
+    # float8nocompile triton implementation
+    y_fp8_row_major, y_fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t(
+        y_bf16,
+        torch.float8_e4m3fn,
+        LinearMMConfig(),
+        algo=algo,
+    )
+
+    # check scales
+    assert torch.eq(x_fp8_row_major._scale, y_fp8_row_major._scale)
+    assert torch.eq(x_fp8_row_major_t._scale, y_fp8_row_major_t._scale)
+
+    # check data
+    assert torch.all(torch.eq(x_fp8_row_major._data, y_fp8_row_major._data))
+    assert torch.all(torch.eq(x_fp8_row_major_t._data, y_fp8_row_major_t._data))
+
+    # check shapes
+    assert x_fp8_row_major.shape == y_fp8_row_major.shape
+    assert x_fp8_row_major_t.shape == y_fp8_row_major_t.shape
+
+    # check strides
+    assert x_fp8_row_major.stride() == y_fp8_row_major.stride()
+    assert x_fp8_row_major_t.stride() == y_fp8_row_major_t.stride()
+
+    # check memory layout
+    assert is_row_major(x_fp8_row_major.stride())
+    assert is_row_major(y_fp8_row_major.stride())
+    assert is_row_major(x_fp8_row_major_t.stride())
+    assert is_row_major(y_fp8_row_major_t.stride())
+
+    # check underlying memory layout
+    assert (
+        x_fp8_row_major._data.storage().tolist()
+        == y_fp8_row_major._data.storage().tolist()
+    )
+    assert (
+        x_fp8_row_major_t._data.storage().tolist()
+        == y_fp8_row_major_t._data.storage().tolist()
+    )
+
+    # assert that error is raised when input tensor is not contiguous
+    with pytest.raises(AssertionError, match="tensor must be contiguous"):
+        hp_to_fp8_row_major_t_and_non_t(
+            y_bf16.t(),  # transpose so tensor memory layout is no longer contiguous
+            torch.float8_e4m3fn,
+            LinearMMConfig(),
+        )
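The stride and layout assertions above amount to the following expectation for, e.g., the (2, 4) case; is_row_major here is an assumed stand-in for the helper imported by the test file:

import torch

def is_row_major(stride):  # assumed stand-in for the helper used by the test file
    return len(stride) == 2 and stride[0] >= stride[1] and stride[1] == 1

rows, cols = 2, 4
row_major = torch.empty((rows, cols), dtype=torch.int8)    # stands in for the fp8 row-major output
row_major_t = torch.empty((cols, rows), dtype=torch.int8)  # stands in for the fp8 transposed output

assert row_major.stride() == (4, 1) and is_row_major(row_major.stride())
assert row_major_t.stride() == (2, 1) and is_row_major(row_major_t.stride())

# A plain .t() view would instead have stride (1, 4), i.e. column-major memory, which is
# why the fused kernel materializes the transposed output contiguously rather than viewing it.
assert row_major.t().stride() == (1, 4)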
