Commit 5a0d662

add torch autograd funcs wrapping new fp8 conversion kernels (#1495)
1 parent: cd62874 · commit: 5a0d662

1 file changed: +113 −57 lines

Diff for: torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py

@@ -16,54 +16,81 @@
     GemmInputRole,
     LinearMMConfig,
 )
+
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
+    hp_to_fp8_col_major,
+    hp_to_fp8_col_major_t,
+    hp_to_fp8_row_and_col_major,
+    hp_to_fp8_row_major,
+    hp_to_fp8_row_major_t,
     KernelAlgorithm,
-    triton_hp_tensor_to_float8_dynamic,
 )
 
-# avoid division by zero when calculating scale
-# TODO: align this value with NVIDIA's assumptions (current value is a guess)
-EPS = 1e-12
 
+class ToFP8RowAndColumnMajor(torch.autograd.Function):
+    """
+    A differentiable conversion to fp8.
+    * forward: convert from high precision to float8 and produces both row-major and column-major outputs
+    * backward: pass the gradient without changes
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
+        fp8_row_major, fp8_col_major = hp_to_fp8_row_and_col_major(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
+        )
+        return fp8_row_major, fp8_col_major
 
-def hp_tensor_to_float8nocompile_dynamic(
-    hp_tensor: torch.Tensor,
-    float8_dtype: torch.dtype,
-    linear_mm_config: LinearMMConfig,
-    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
-) -> Float8Tensor:
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
+
+
+class ToFP8RowMajor(torch.autograd.Function):
     """
-    Given a high precision tensor `hp_tensor`,
-    scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result.
-
-    Args:
-        hp_tensor: the tensor to convert
-        float8_dtype: the float8 dtype to use
-        linear_mm_config: Defines the configuration for the scaled_mm for
-            the 3 fwd/bwd gemms of linear
-        gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
-            the 3 fwd/bwd gemms of linear
+    A differentiable conversion to fp8 in row-major layout.
+    * forward: convert from high precision to float8 with row-major memory layout
+    * backward: pass the gradient without changes
     """
-    # TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
-    # torch.compile and eager show different numerics for 1.0 / float32,
-    # upcast to float64 to ensure same numeric between compile and eager
-    amax = torch.max(torch.abs(hp_tensor)).to(torch.float64)
-    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
-    scale = scale.to(torch.float32)  # scale must be fp32
-    return _ToFloat8ConstrFunc.apply(
-        hp_tensor,
-        scale,
-        float8_dtype,
-        linear_mm_config,
-        gemm_input_role,
-        None,
-    )
-
-
-class Float8NoCompileConversionFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
+        fp8_row_major = hp_to_fp8_row_major(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
+        )
+        return fp8_row_major
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
+
+
+class ToFP8RowMajorT(torch.autograd.Function):
     """
-    A differentiable conversion to fp8.
-    * forward: convert from high precision to float8
+    A differentiable conversion to fp8 with transposed dimensions in row-major layout.
+    * forward: convert from high precision to float8 with transposed dimensions with row-major memory layout
+    * backward: pass the gradient without changes
     """
-    * backward: pass the gradient without changes
-    """
 
@@ -76,24 +103,25 @@ def forward(
         gemm_input_role: GemmInputRole,
         kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
     ):
-        return triton_hp_tensor_to_float8_dynamic(
+        fp8_row_major_t = hp_to_fp8_row_major_t(
             tensor,
             float8_dtype,
             linear_mm_config,
             gemm_input_role,
             algo=kernel_algo,
         )
+        return fp8_row_major_t
 
     @staticmethod
     def backward(ctx, g):
-        return g, None, None, None, None, None
+        return g, None, None, None, None
 
 
-class NoopFwToFloat8NoCompileBwDynamic(torch.autograd.Function):
+class ToFP8ColumnMajor(torch.autograd.Function):
     """
-    A differentiable conversion to fp8.
-    * forward: no-op
-    * backward: convert to float8 with tensor-wise dynamic scaling
+    A differentiable conversion to fp8 in column-major layout.
+    * forward: convert from high precision to float8 with column-major memory layout
+    * backward: pass the gradient without changes
     """
 
     @staticmethod
@@ -102,20 +130,48 @@ def forward(
         tensor: torch.Tensor,
         float8_dtype: torch.dtype,
         linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
         kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
     ):
-        ctx.linear_mm_config = linear_mm_config
-        ctx.target_dtype = float8_dtype
-        ctx.kernel_algo = kernel_algo
-        return tensor
+        fp8_col_major = hp_to_fp8_col_major(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
+        )
+        return fp8_col_major
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
+
+
+class ToFP8ColumnMajorT(torch.autograd.Function):
+    """
+    A differentiable conversion to fp8 with transposed dimensions in column-major layout.
+    * forward: convert from high precision to float8 with transposed dimensions in column-major memory layout.
+    * backward: pass the gradient without changes
+    """
 
     @staticmethod
-    def backward(ctx, gradY):
-        fp8_tensor = triton_hp_tensor_to_float8_dynamic(
-            gradY,
-            ctx.target_dtype,
-            ctx.linear_mm_config,
-            GemmInputRole.GRAD_OUTPUT,
-            ctx.kernel_algo,
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
+        fp8_col_major_t = hp_to_fp8_col_major_t(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
         )
-        return fp8_tensor, None, None, None
+        return fp8_col_major_t
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
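
For reference, a minimal usage sketch (not part of this commit) showing how one of the new autograd wrappers can be invoked from user code. It assumes a CUDA device for the underlying hp_to_fp8_* Triton kernels, a default-constructed LinearMMConfig, and the Float8Tensor.to_original_precision() accessor; treat those as assumptions rather than guarantees of this prototype's API.

# Usage sketch (assumptions noted above), not part of the commit itself.
import torch

from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
    ToFP8RowAndColumnMajor,
)
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
)

x = torch.randn(128, 256, device="cuda", requires_grad=True)

# Forward: a single fused kernel call returns both row-major and column-major
# fp8 copies of the input, each carrying the tensorwise dynamic scale.
fp8_row, fp8_col = ToFP8RowAndColumnMajor.apply(
    x,
    torch.float8_e4m3fn,
    LinearMMConfig(),  # assumed: default config for the 3 fwd/bwd gemms
    GemmInputRole.INPUT,
    KernelAlgorithm.ATOMIC_MAX,
)

# Backward: the wrapper passes gradients through unchanged, so autograd treats
# the conversion as an identity with respect to the high-precision input.
fp8_row.to_original_precision().sum().backward()  # assumed Float8Tensor API
assert x.grad is not None

The same calling convention applies to ToFP8RowMajor, ToFP8RowMajorT, ToFP8ColumnMajor, and ToFP8ColumnMajorT; only the layout (and transposition) of the returned fp8 tensor differs.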
