
Commit 7fc8ad4

float8 training: clean up recipe names (#1730)
Update [ghstack-poisoned]
1 parent c59561a commit 7fc8ad4

6 files changed: +15 additions, -15 deletions
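
The rename is mechanical: ALL_TENSORWISE becomes TENSORWISE, ALL_AXISWISE becomes ROWWISE, and LW_AXISWISE_WITH_GW_HP becomes ROWWISE_WITH_GW_HP. As a rough sketch of how a call site picks up the new names after this commit (the toy model and the torchao.float8 import path are illustrative assumptions, not part of the diff):

import torch
import torch.nn as nn

# assumed import paths; Float8LinearRecipeName and recipe_name_to_linear_config
# live in torchao/float8/config.py per the diff below
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# was Float8LinearRecipeName.ALL_AXISWISE before this commit
config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)

# hypothetical toy model, used only to illustrate the call
m = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
m_fp8 = convert_to_float8_training(m, config=config)
m_fp8 = torch.compile(m_fp8)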

benchmarks/float8/float8_roofline.py

Lines changed: 2 additions & 2 deletions

@@ -349,7 +349,7 @@ def run(
 
     # get the float8 dynamic axiswise scaling gpu kernel time
     torch._dynamo.reset()
-    config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+    config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
     m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
     m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
     fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +358,7 @@ def run(
     # TODO(future PR): enable below once basic performance issues
     # are fixed
     # torch._dynamo.reset()
-    # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+    # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
     # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
     # m_fp8_lw = torch.compile(m_fp8_lw)
     # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)

test/float8/test_base.py

Lines changed: 2 additions & 2 deletions

@@ -420,8 +420,8 @@ def test_linear_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])

test/float8/test_compile.py

Lines changed: 2 additions & 2 deletions

@@ -218,8 +218,8 @@ def test_inductor_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @unittest.skipIf(

test/float8/test_dtensor.py

Lines changed: 1 addition & 1 deletion

@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     device = mesh.device_type
 
     if rowwise:
-        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
         # hack around config being frozen
         # TODO(future PR): we should make this nicer at the config level
         object.__setattr__(config, "emulate", True)

test/float8/test_numerics_integration.py

Lines changed: 2 additions & 2 deletions

@@ -198,8 +198,8 @@ def test_encoder_fw_bw_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.skipif(

torchao/float8/config.py

Lines changed: 6 additions & 6 deletions

@@ -326,9 +326,9 @@ def __post_init__(self):
 # TODO(future PR): go through a round of design on this, and eventually expose
 # as a top level public API.
 class Float8LinearRecipeName(enum.Enum):
-    ALL_TENSORWISE = "all_tensorwise"
-    ALL_AXISWISE = "all_axiswise"
-    LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp"
+    TENSORWISE = "tensorwise"
+    ROWWISE = "rowwise"
+    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
 
 
 def recipe_name_to_linear_config(
@@ -339,11 +339,11 @@ def recipe_name_to_linear_config(
     Output: a `Float8LinearConfig` configured to implement the recipe
     """
 
-    if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE:
+    if recipe_name is Float8LinearRecipeName.TENSORWISE:
         # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
         return Float8LinearConfig()
 
-    elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
         cc_i = CastConfig(
             scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
@@ -363,7 +363,7 @@ def recipe_name_to_linear_config(
             round_scales_to_power_of_2=True,
         )
 
-    elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
         # lw's recipe for a modification on all-axiswise:
         #
         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
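
Because the enum value strings change along with the member names (e.g. "all_axiswise" becomes "rowwise"), any code that looks a recipe up by string has to be updated too. A minimal sketch of that lookup, using only standard enum-by-value construction rather than anything added by this commit:

from torchao.float8.config import Float8LinearRecipeName

# the value string was "all_axiswise" before this commit
recipe = Float8LinearRecipeName("rowwise")
assert recipe is Float8LinearRecipeName.ROWWISE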
