This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b9b606e

vkuzo authored and facebook-github-bot committed
add per-gemm config to Float8LinearConfig (#334)
Summary:

Pull Request resolved: #334

Previously, the per-gemm configuration had to be hardcoded in library code. This PR exposes it to the top-level UX by adding a `Float8GemmConfig` field to `Float8LinearConfig`. Note that today the only supported configuration option is `use_fast_accum`. In the future, configuring `output_dtype` and whether to keep a gemm in higher precision would go here.

Reviewed By: weifengpy

Differential Revision: D60252069

fbshipit-source-id: bca34eb49e1bf046f937e32b11b2369b535d56e6
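For reference, here is a minimal sketch of the resulting top-level UX, using only names added or re-exported in this diff (`Float8GemmConfig`, `Float8LinearConfig`, and the three `gemm_config_*` fields). It assumes the package is importable as `float8_experimental` and that the remaining `Float8LinearConfig` fields keep their defaults:

```python
from float8_experimental import Float8GemmConfig, Float8LinearConfig

# Per-gemm fast-accumulation choices for the three gemms of a linear layer:
# `output` (forward), `grad_input` and `grad_weight` (backward).
# Assumes all other Float8LinearConfig fields have usable defaults.
config = Float8LinearConfig(
    gemm_config_output=Float8GemmConfig(use_fast_accum=True),
    gemm_config_grad_input=Float8GemmConfig(use_fast_accum=False),
    gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=False),
)
```

These values mirror the defaults added in `config.py` below: fast accumulation stays enabled only for the forward `output` gemm.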
1 parent ed1693e · commit b9b606e

File tree

3 files changed: +32 -7 lines changed

float8_experimental/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 # Lets define a few top level things here
 from float8_experimental.config import (
     DelayedScalingConfig,
+    Float8GemmConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -33,6 +34,7 @@
     # configuration
     "DelayedScalingConfig",
     "TensorScalingType",
+    "Float8GemmConfig",
     "Float8LinearConfig",
     "Float8TensorCastConfig",
     # top level UX
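As a quick sanity check of the export above (a sketch; it assumes the package is installed and that the string list being modified is the module's `__all__`), the name re-exported from the package root resolves to the same class defined in `float8_experimental.config`:

```python
import float8_experimental
from float8_experimental.config import Float8GemmConfig

# `Float8GemmConfig` is now part of the public surface listed in __all__
# and is the same object as the definition in config.py.
assert "Float8GemmConfig" in float8_experimental.__all__
assert float8_experimental.Float8GemmConfig is Float8GemmConfig
```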

float8_experimental/config.py

Lines changed: 19 additions & 0 deletions
@@ -53,6 +53,17 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
+@dataclass(frozen=True)
+class Float8GemmConfig:
+    """
+    Configuration for a float8 gemm.
+    """
+
+    # If True, fast accumulation in lower precision is used.
+    # Note: this flag is currently a no-op if emulation is turned on.
+    use_fast_accum: bool = False
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -67,6 +78,14 @@ class Float8LinearConfig:
     cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
     cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
 
+    #
+    # Per-gemm configuration for gemms calculating `output`, `grad_input` and
+    # `grad_weight`
+    #
+    gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
+    gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
+    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()
+
     #
     # Per-linear configuration
     #
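A small check of the defaults added above (a sketch; it assumes the remaining `Float8LinearConfig` fields also have defaults, so the config is constructible with no arguments): only the forward `output` gemm opts into fast accumulation out of the box.

```python
from float8_experimental.config import Float8GemmConfig, Float8LinearConfig

# Assumes Float8LinearConfig() can be constructed with no arguments.
config = Float8LinearConfig()
assert config.gemm_config_output.use_fast_accum is True
assert config.gemm_config_grad_input.use_fast_accum is False
assert config.gemm_config_grad_weight.use_fast_accum is False

# Float8GemmConfig is a frozen dataclass, so a per-gemm choice is made at
# construction time rather than mutated on an existing config.
fast = Float8GemmConfig(use_fast_accum=True)
```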

float8_experimental/float8_linear.py

Lines changed: 11 additions & 7 deletions
@@ -168,24 +168,28 @@ def __init__(self, *args, **kwargs):
 
         self.create_buffers()
 
-        # TODO(future): user level configuration of gemms
         self.linear_mm_config = LinearMMConfig(
-            # input
+            # output
            ScaledMMConfig(
                emulate,
-                True if not emulate else False,
+                self.config.gemm_config_output.use_fast_accum,
                False,
                self.config.pad_inner_dim,
            ),
-            # weight
+            # grad_input
            ScaledMMConfig(
                emulate,
-                True if not emulate else False,
+                self.config.gemm_config_grad_input.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+            # grad_weight
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_weight.use_fast_accum,
                False,
                self.config.pad_inner_dim,
            ),
-            # grad_output
-            ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim),
        )
 
        # Note: is_amax_initialized is not a buffer to avoid data dependent
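The change above repeats the same `ScaledMMConfig` construction three times, once per per-gemm config. A hypothetical helper (not part of this PR) that mirrors that wiring could look like the sketch below; the import path for `LinearMMConfig`/`ScaledMMConfig` is assumed to be `float8_experimental.float8_tensor`, and the positional argument order follows the diff.

```python
from float8_experimental.config import Float8GemmConfig, Float8LinearConfig
# Import path assumed; these names are only constructed positionally above.
from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig


def build_linear_mm_config(config: Float8LinearConfig, emulate: bool) -> LinearMMConfig:
    """Mirror the updated __init__ in float8_linear.py: one ScaledMMConfig per
    gemm, in (output, grad_input, grad_weight) order."""

    def scaled_mm(gemm_config: Float8GemmConfig) -> ScaledMMConfig:
        return ScaledMMConfig(
            emulate,
            gemm_config.use_fast_accum,
            False,  # third positional argument is passed as False in the diff
            config.pad_inner_dim,
        )

    return LinearMMConfig(
        scaled_mm(config.gemm_config_output),
        scaled_mm(config.gemm_config_grad_input),
        scaled_mm(config.gemm_config_grad_weight),
    )
```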
