|
7 | 7 | import enum
|
8 | 8 | import logging
|
9 | 9 | from dataclasses import dataclass
|
10 |
| -from typing import Optional |
| 10 | +from typing import Optional, Union |
11 | 11 |
|
12 | 12 | import torch
|
13 | 13 |
|
@@ -146,6 +146,32 @@ class Float8GemmConfig:
|
146 | 146 | use_fast_accum: bool = False
|
147 | 147 |
|
148 | 148 |
|
| 149 | +# Pre-made recipes for common configurations |
| 150 | +class Float8LinearRecipeName(enum.Enum): |
| 151 | + |
| 152 | + # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel |
| 153 | + TENSORWISE = "tensorwise" |
| 154 | + |
| 155 | + # dynamic rowwise scaling with the CUTLASS rowwise kernel |
| 156 | + # * e4m3 for activations, weights, gradients |
| 157 | + # * scales rounded down (floor) to the nearest lower power of two for increased accuracy |
| 158 | + ROWWISE = "rowwise" |
| 159 | + |
| 160 | + # lw's recipe for a modification on rowwise scaling: |
| 161 | + # |
| 162 | + # output_hp = input_fp8_rowwise_dim0 @ weight_t_rowwise_dim1 |
| 163 | + # grad_input_hp = grad_output_fp8_rowwise_dim0 @ weight_fp8_tensorwise |
| 164 | + # grad_weight_hp = input_t_hp @ grad_output_hp |
| 165 | + # |
| 166 | + # key characteristics: |
| 167 | + # * increased accuracy for grad_weight |
| 168 | + # * `input`, `weight` and `grad_output` now each only need to be scaled |
| 169 | + # rowwise across a single dim, compared to vanilla rowwise which scales |
| 170 | + # each tensor along both of its usage dims; this is more amenable to fast kernels |
| 171 | + # * the e4m3 dtype is used across the board, including for gradients |
| 172 | + ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp" |
| 173 | + |
| 174 | + |
149 | 175 | @dataclass(frozen=True)
|
150 | 176 | class Float8LinearConfig:
|
151 | 177 | """
|
@@ -321,86 +347,69 @@ def __post_init__(self):
|
321 | 347 | "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details."
|
322 | 348 | )
|
323 | 349 |
|
| 350 | + @staticmethod |
| 351 | + def from_recipe_name( |
| 352 | + recipe_name: Union[Float8LinearRecipeName, str], |
| 353 | + ) -> "Float8LinearConfig": |
| 354 | + """ |
| 355 | + Input: `Float8LinearRecipeName` value, or a string representing a `Float8LinearRecipeName` value |
| 356 | + Output: a `Float8LinearConfig` configured to implement the specified recipe |
| 357 | + """ |
| 358 | + if type(recipe_name) == str: |
| 359 | + valid_names = [n.value for n in Float8LinearRecipeName] |
| 360 | + assert ( |
| 361 | + recipe_name in valid_names |
| 362 | + ), f"recipe_name {recipe_name} not in valid names {valid_names}" |
| 363 | + recipe_name = Float8LinearRecipeName(recipe_name) |
324 | 364 |
|
325 |
| -# Pre-made recipes for common configurations |
326 |
| -# TODO(future PR): go through a round of design on this, and eventually expose |
327 |
| -# as a top level public API. |
328 |
| -class Float8LinearRecipeName(enum.Enum): |
329 |
| - TENSORWISE = "tensorwise" |
330 |
| - ROWWISE = "rowwise" |
331 |
| - ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp" |
| 365 | + if recipe_name is Float8LinearRecipeName.TENSORWISE: |
| 366 | + return Float8LinearConfig() |
| 367 | + |
| 368 | + elif recipe_name is Float8LinearRecipeName.ROWWISE: |
| 369 | + cc_i = CastConfig( |
| 370 | + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
| 371 | + ) |
| 372 | + cc_w = CastConfig( |
| 373 | + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
| 374 | + ) |
| 375 | + cc_go = CastConfig( |
| 376 | + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
| 377 | + ) |
332 | 378 |
|
| 379 | + return Float8LinearConfig( |
| 380 | + cast_config_input=cc_i, |
| 381 | + cast_config_weight=cc_w, |
| 382 | + cast_config_grad_output=cc_go, |
| 383 | + # enable power of 2 scaling factors by default for row-wise scaling |
| 384 | + round_scales_to_power_of_2=True, |
| 385 | + ) |
333 | 386 |
|
334 |
| -def recipe_name_to_linear_config( |
335 |
| - recipe_name: Float8LinearRecipeName, |
336 |
| -) -> Float8LinearConfig: |
337 |
| - """ |
338 |
| - Input: `Float8LinearRecipeName` value |
339 |
| - Output: a `Float8LinearConfig` configured to implement the recipe |
340 |
| - """ |
| 387 | + elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP: |
341 | 388 |
|
342 |
| - if recipe_name is Float8LinearRecipeName.TENSORWISE: |
343 |
| - # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel |
344 |
| - return Float8LinearConfig() |
345 |
| - |
346 |
| - elif recipe_name is Float8LinearRecipeName.ROWWISE: |
347 |
| - # dynamic axiswise scaling with the CUTLASS rowwise kernel |
348 |
| - cc_i = CastConfig( |
349 |
| - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
350 |
| - ) |
351 |
| - cc_w = CastConfig( |
352 |
| - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
353 |
| - ) |
354 |
| - cc_go = CastConfig( |
355 |
| - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
356 |
| - ) |
357 |
| - |
358 |
| - return Float8LinearConfig( |
359 |
| - cast_config_input=cc_i, |
360 |
| - cast_config_weight=cc_w, |
361 |
| - cast_config_grad_output=cc_go, |
362 |
| - # enable power of 2 scaling factors by default for row-wise scaling |
363 |
| - round_scales_to_power_of_2=True, |
364 |
| - ) |
365 |
| - |
366 |
| - elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP: |
367 |
| - # lw's recipe for a modification on all-axiswise: |
368 |
| - # |
369 |
| - # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 |
370 |
| - # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise |
371 |
| - # grad_weight_hp = input_t_hp @ grad_output_hp |
372 |
| - # |
373 |
| - # key characteristics: |
374 |
| - # * increased accuracy for grad_weight |
375 |
| - # * `input`, `weight` and `grad_output` now only need to be scaled |
376 |
| - # axiswise across a single dim compared to vanilla all-axiswise, |
377 |
| - # which is more amenable to fast kernels |
378 |
| - # * the e4m3 dtype is used across the board, including for gradients |
379 |
| - |
380 |
| - # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 |
381 |
| - cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) |
382 |
| - cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) |
383 |
| - |
384 |
| - # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise |
385 |
| - cc_go = CastConfig( |
386 |
| - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
387 |
| - ) |
388 |
| - cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) |
389 |
| - |
390 |
| - # grad_weight_hp = input_t_hp @ grad_output_hp |
391 |
| - cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) |
392 |
| - cc_go_gw = CastConfig( |
393 |
| - scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype |
394 |
| - ) |
395 |
| - |
396 |
| - return Float8LinearConfig( |
397 |
| - cast_config_input=cc_i, |
398 |
| - cast_config_weight=cc_w, |
399 |
| - cast_config_grad_output=cc_go, |
400 |
| - cast_config_input_for_grad_weight=cc_i_gw, |
401 |
| - cast_config_weight_for_grad_input=cc_w_gi, |
402 |
| - cast_config_grad_output_for_grad_weight=cc_go_gw, |
403 |
| - ) |
404 |
| - |
405 |
| - else: |
406 |
| - raise AssertionError(f"unknown recipe_name {recipe_name}") |
| 389 | + # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 |
| 390 | + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) |
| 391 | + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) |
| 392 | + |
| 393 | + # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise |
| 394 | + cc_go = CastConfig( |
| 395 | + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype |
| 396 | + ) |
| 397 | + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) |
| 398 | + |
| 399 | + # grad_weight_hp = input_t_hp @ grad_output_hp |
| 400 | + cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) |
| 401 | + cc_go_gw = CastConfig( |
| 402 | + scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype |
| 403 | + ) |
| 404 | + |
| 405 | + return Float8LinearConfig( |
| 406 | + cast_config_input=cc_i, |
| 407 | + cast_config_weight=cc_w, |
| 408 | + cast_config_grad_output=cc_go, |
| 409 | + cast_config_input_for_grad_weight=cc_i_gw, |
| 410 | + cast_config_weight_for_grad_input=cc_w_gi, |
| 411 | + cast_config_grad_output_for_grad_weight=cc_go_gw, |
| 412 | + ) |
| 413 | + |
| 414 | + else: |
| 415 | + raise AssertionError(f"unknown recipe_name {recipe_name}") |
0 commit comments