Commit c6c388b

float8 training: make the "config from recipe" API polished (#1731)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 7fc8ad4 commit c6c388b
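The change swaps the standalone recipe_name_to_linear_config helper for a Float8LinearConfig.from_recipe_name staticmethod that also accepts plain strings. A minimal before/after sketch of the call sites touched below:

# before this commit: enum value required, free function returns the config
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)

# after this commit: staticmethod on the config; either the enum or its string value works
from torchao.float8.config import Float8LinearConfig
config = Float8LinearConfig.from_recipe_name("rowwise")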

File tree: 7 files changed, +97 -95 lines


benchmarks/float8/float8_roofline.py

Lines changed: 2 additions & 3 deletions
@@ -63,7 +63,6 @@
     ScalingType,
     convert_to_float8_training,
 )
-from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
 from torchao.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -349,7 +348,7 @@ def run(
 
     # get the float8 dynamic axiswise scaling gpu kernel time
     torch._dynamo.reset()
-    config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
+    config = Float8LinearConfig.from_recipe_name("rowwise")
     m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
     m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
     fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +357,7 @@
     # TODO(future PR): enable below once basic performance issues
     # are fixed
     # torch._dynamo.reset()
-    # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
+    # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
     # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
     # m_fp8_lw = torch.compile(m_fp8_lw)
     # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
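For reference, a minimal sketch (not from the benchmark itself) of the pattern the updated lines use: build the config from the "rowwise" recipe, swap the Linear modules for float8 training, and compile. The toy model and shapes are hypothetical stand-ins for the benchmark's m_orig and input; running this needs a CUDA GPU with float8 support.

import copy

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

# hypothetical stand-in for the benchmark's m_orig
m_orig = nn.Sequential(
    nn.Linear(1024, 1024, bias=False),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=False),
).to("cuda", torch.bfloat16)
x = torch.randn(2048, 1024, device="cuda", dtype=torch.bfloat16)

# recipe name -> config -> convert Linear modules -> compile, mirroring the diff above
config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
out = m_fp8_dyn_axs(x)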

benchmarks/float8/profile_linear_float8.py

Lines changed: 2 additions & 4 deletions
@@ -39,9 +39,8 @@
 
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
-    Float8LinearRecipeName,
+    Float8LinearConfig,
     ScalingType,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
@@ -311,8 +310,7 @@ def main(
             emulate=False,
         )
     elif recipe_name is not None:
-        recipe_name = Float8LinearRecipeName(recipe_name)
-        config = recipe_name_to_linear_config(recipe_name)
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
 
     scaling_repr = "_".join(
         [

test/float8/test_base.py

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,6 @@
     ScalingType,
     e4m3_dtype,
     e5m2_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -442,7 +441,7 @@ def test_linear_from_recipe(
         linear_dtype = torch.bfloat16
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
-        config = recipe_name_to_linear_config(recipe_name)
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
         self._test_linear_impl(
             x,
             m_ref,

test/float8/test_compile.py

Lines changed: 1 addition & 2 deletions
@@ -33,7 +33,6 @@
     Float8LinearRecipeName,
     ScalingType,
     e4m3_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -227,7 +226,7 @@ def test_inductor_from_config_params(
 )
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
-    config = recipe_name_to_linear_config(recipe_name)
+    config = Float8LinearConfig.from_recipe_name(recipe_name)
     fullgraph = True
     dtype = torch.bfloat16
     _test_compile_base(

test/float8/test_dtensor.py

Lines changed: 1 addition & 2 deletions
@@ -41,7 +41,6 @@
     Float8LinearRecipeName,
     ScalingType,
     e4m3_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
@@ -198,7 +197,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     device = mesh.device_type
 
     if rowwise:
-        config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
+        config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
         # hack around config being frozen
         # TODO(future PR): we should make this nicer at the config level
         object.__setattr__(config, "emulate", True)
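Context for the workaround above: Float8LinearConfig is declared with @dataclass(frozen=True) (see the config.py diff below), so plain attribute assignment is rejected at runtime. A small sketch of the pattern, with the direct assignment shown only to illustrate the failure:

import dataclasses

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)

try:
    config.emulate = True  # rejected: the config dataclass is frozen
except dataclasses.FrozenInstanceError:
    # the test's workaround: bypass the frozen __setattr__ directly
    object.__setattr__(config, "emulate", True)

assert config.emulate is True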

test/float8/test_numerics_integration.py

Lines changed: 1 addition & 2 deletions
@@ -28,7 +28,6 @@
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingType,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
@@ -210,7 +209,7 @@ def test_encoder_fw_bw_from_recipe(
         self,
         recipe_name: str,
     ):
-        config = recipe_name_to_linear_config(recipe_name)
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
         self._test_impl(config)

torchao/float8/config.py

Lines changed: 89 additions & 80 deletions
@@ -7,7 +7,7 @@
 import enum
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
@@ -146,6 +146,32 @@ class Float8GemmConfig:
     use_fast_accum: bool = False
 
 
+# Pre-made recipes for common configurations
+class Float8LinearRecipeName(enum.Enum):
+
+    # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
+    TENSORWISE = "tensorwise"
+
+    # dynamic rowwise scaling with the CUTLASS rowwise kernel
+    # * e4m3 for activations, weights, gradients
+    # * scales rounded (floor) to the nearest power of two for increased accuracy
+    ROWWISE = "rowwise"
+
+    # lw's recipe for a modification on rowwise scaling:
+    #
+    #   output_hp = input_fp8_rowwise_dim0 @ weight_t_rowwise_dim1
+    #   grad_input_hp = grad_output_fp8_rowwise_dim0 @ weight_fp8_tensorwise
+    #   grad_weight_hp = input_t_hp @ grad_output_hp
+    #
+    # key characteristics:
+    #   * increased accuracy for grad_weight
+    #   * `input`, `weight` and `grad_output` now only need to be scaled
+    #     rowwise across a single dim compared to vanilla rowwise,
+    #     which is more amenable to fast kernels
+    #   * the e4m3 dtype is used across the board, including for gradients
+    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -321,86 +347,69 @@ def __post_init__(self):
                 "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details."
             )
 
+    @staticmethod
+    def from_recipe_name(
+        recipe_name: Union[Float8LinearRecipeName, str],
+    ) -> "Float8LinearConfig":
+        """
+        Input: `Float8LinearRecipeName` value, or a string representing a `Float8LinearRecipeName` value
+        Output: a `Float8LinearConfig` configured to implement the specified recipe
+        """
+        if type(recipe_name) == str:
+            valid_names = [n.value for n in Float8LinearRecipeName]
+            assert (
+                recipe_name in valid_names
+            ), f"recipe_name {recipe_name} not in valid names {valid_names}"
+            recipe_name = Float8LinearRecipeName(recipe_name)
 
-# Pre-made recipes for common configurations
-# TODO(future PR): go through a round of design on this, and eventually expose
-# as a top level public API.
-class Float8LinearRecipeName(enum.Enum):
-    TENSORWISE = "tensorwise"
-    ROWWISE = "rowwise"
-    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
+        if recipe_name is Float8LinearRecipeName.TENSORWISE:
+            return Float8LinearConfig()
+
+        elif recipe_name is Float8LinearRecipeName.ROWWISE:
+            cc_i = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
+            cc_w = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
+            cc_go = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
 
+            return Float8LinearConfig(
+                cast_config_input=cc_i,
+                cast_config_weight=cc_w,
+                cast_config_grad_output=cc_go,
+                # enable power of 2 scaling factors by default for row-wise scaling
+                round_scales_to_power_of_2=True,
+            )
 
-def recipe_name_to_linear_config(
-    recipe_name: Float8LinearRecipeName,
-) -> Float8LinearConfig:
-    """
-    Input: `Float8LinearRecipeName` value
-    Output: a `Float8LinearConfig` configured to implement the recipe
-    """
+        elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
 
-    if recipe_name is Float8LinearRecipeName.TENSORWISE:
-        # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
-        return Float8LinearConfig()
-
-    elif recipe_name is Float8LinearRecipeName.ROWWISE:
-        # dynamic axiswise scaling with the CUTLASS rowwise kernel
-        cc_i = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-        cc_w = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-        cc_go = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-
-        return Float8LinearConfig(
-            cast_config_input=cc_i,
-            cast_config_weight=cc_w,
-            cast_config_grad_output=cc_go,
-            # enable power of 2 scaling factors by default for row-wise scaling
-            round_scales_to_power_of_2=True,
-        )
-
-    elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
-        # lw's recipe for a modification on all-axiswise:
-        #
-        # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        # grad_weight_hp = input_t_hp @ grad_output_hp
-        #
-        # key characteristics:
-        # * increased accuracy for grad_weight
-        # * `input`, `weight` and `grad_output` now only need to be scaled
-        #   axiswise across a single dim compared to vanilla all-axiswise,
-        #   which is more amenable to fast kernels
-        # * the e4m3 dtype is used across the board, including for gradients
-
-        # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-
-        # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-        cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
-
-        # grad_weight_hp = input_t_hp @ grad_output_hp
-        cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(
-            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
-        )
-
-        return Float8LinearConfig(
-            cast_config_input=cc_i,
-            cast_config_weight=cc_w,
-            cast_config_grad_output=cc_go,
-            cast_config_input_for_grad_weight=cc_i_gw,
-            cast_config_weight_for_grad_input=cc_w_gi,
-            cast_config_grad_output_for_grad_weight=cc_go_gw,
-        )
-
-    else:
-        raise AssertionError(f"unknown recipe_name {recipe_name}")
+            # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
+            cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+            cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+
+            # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
+            cc_go = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
+            cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
+
+            # grad_weight_hp = input_t_hp @ grad_output_hp
+            cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+            cc_go_gw = CastConfig(
+                scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
+            )
+
+            return Float8LinearConfig(
+                cast_config_input=cc_i,
+                cast_config_weight=cc_w,
+                cast_config_grad_output=cc_go,
+                cast_config_input_for_grad_weight=cc_i_gw,
+                cast_config_weight_for_grad_input=cc_w_gi,
+                cast_config_grad_output_for_grad_weight=cc_go_gw,
+            )
+
+        else:
+            raise AssertionError(f"unknown recipe_name {recipe_name}")
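Putting the new surface together: from_recipe_name accepts either a Float8LinearRecipeName member or its string value ("tensorwise", "rowwise", "rowwise_with_gw_hp"), and an unknown string trips the assert with the list of valid names. A short usage sketch:

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

# enum member and its string value select the same recipe
config_rowwise = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
config_rowwise_str = Float8LinearConfig.from_recipe_name("rowwise")

# the "tensorwise" recipe maps to the default Float8LinearConfig()
config_default = Float8LinearConfig.from_recipe_name("tensorwise")

# an unrecognized string fails the assert inside from_recipe_name
try:
    Float8LinearConfig.from_recipe_name("not_a_recipe")
except AssertionError as e:
    print(e)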
