
Commit 585cdfe

add axiswise scaling to Float8Linear
Summary:

This PR supports scaling all arguments of all gemms axiswise, and ensures that training with axiswise scaling works e2e.

Future PR: support more granular configurability, optimize performance, and add docs.

Test Plan:

```
// tests pass
./test/float8/test_everything.sh

// sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8:
// 1. verify performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane, performance isn't there though
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 77d62e8
ghstack-comment-id: 2368837904
Pull Request resolved: #920
1 parent be5a4a8 commit 585cdfe
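
For orientation, here is a minimal sketch of how the new knob fits together end to end, assembled only from the APIs visible in the diffs below (CastConfig, ScalingType, ScalingGranularity, Float8LinearConfig, convert_to_float8_training). It is a hedged illustration, not code from this commit: the cast_config_weight / cast_config_grad_output keyword names and the config kwarg of convert_to_float8_training are assumed to mirror the pattern shown for cast_config_input.

```python
# Hedged sketch (not taken verbatim from this commit): configure axiswise
# scaling for input, weight and grad_output using the config objects that
# the diffs below import and extend.
import torch
import torch.nn as nn

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    ScalingGranularity,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

# the new tests only exercise axiswise scaling with dynamic scaling and bfloat16
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.AXISWISE,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,       # keyword assumed to mirror cast_config_input
    cast_config_grad_output=cast_config,  # keyword assumed to mirror cast_config_input
)

# swap nn.Linear modules for Float8Linear with the config above
m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).to(device="cuda", dtype=torch.bfloat16)
convert_to_float8_training(m, config=config)  # config kwarg assumed
```

Per the new test parametrization, only the all-dynamic, bfloat16 combination is currently exercised with axiswise scaling; other combinations are skipped.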

8 files changed: +434 −39 lines changed


benchmarks/float8/bench_linear_float8.py

Lines changed: 27 additions & 5 deletions
```diff
@@ -14,7 +14,12 @@
 
 import torch
 import torch.utils.benchmark as benchmark
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
     linear_requires_sync,
@@ -107,35 +112,49 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
@@ -167,7 +186,7 @@ def main(
         copy.deepcopy(linear_ref),
         config=config,
     )
-    scaling_repr = linear_float8.scaling_repr()
+    scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
 
     if fast_accum:
         linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -310,6 +329,7 @@ def invoke_main() -> None:
     parser.add_argument("--scaling_type_input", type=str, required=False)
     parser.add_argument("--scaling_type_weight", type=str, required=False)
     parser.add_argument("--scaling_type_grad_output", type=str, required=False)
+    parser.add_argument("--scaling_granularity", type=str, required=False)
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
     kwargs = {}
@@ -327,6 +347,8 @@ def invoke_main() -> None:
         kwargs["scaling_type_weight"] = args.scaling_type_weight
     if args.scaling_type_grad_output is not None:
         kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
+    if args.scaling_granularity is not None:
+        kwargs["scaling_granularity"] = args.scaling_granularity
     main(
         output_path,
         not args.disable_compile,
```
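
One detail of the plumbing above: the benchmark passes the granularity around as a plain string and converts it to the enum inside main(), so --scaling_granularity values must match the enum's string values. A small hedged check, assuming the members use lowercase values as the "tensorwise" default in the diff suggests:

```python
from torchao.float8.config import ScalingGranularity

# main() calls ScalingGranularity(scaling_granularity) on the raw CLI string;
# "tensorwise" is the default shown in the diff, and "axiswise" is assumed to
# follow the same lowercase naming convention.
assert ScalingGranularity("tensorwise") is ScalingGranularity.TENSORWISE
assert ScalingGranularity("axiswise") is ScalingGranularity.AXISWISE
```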

benchmarks/float8/profile_linear_float8.py

Lines changed: 23 additions & 4 deletions
```diff
@@ -22,7 +22,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -252,6 +257,7 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -263,28 +269,41 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
```

test/float8/test_base.py

Lines changed: 29 additions & 3 deletions
```diff
@@ -327,6 +327,10 @@ def _test_linear_impl(
         "scaling_type_grad_output",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
     )
+    @pytest.mark.parametrize(
+        "scaling_granularity",
+        [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
+    )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -337,6 +341,7 @@ def test_linear(
         scaling_type_input: ScalingType,
         scaling_type_weight: ScalingType,
         scaling_type_grad_output: ScalingType,
+        scaling_granularity: ScalingGranularity,
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
@@ -349,30 +354,51 @@ def test_linear(
                 f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
             )
             pytest.skip()
+        if scaling_granularity is ScalingGranularity.AXISWISE:
+            if (
+                scaling_type_input != ScalingType.DYNAMIC or
+                scaling_type_weight != ScalingType.DYNAMIC or
+                scaling_type_grad_output != ScalingType.DYNAMIC or
+                linear_dtype != torch.bfloat16
+            ):
+                pytest.skip()
+
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
 
         if scaling_type_input is ScalingType.STATIC:
             cast_config_input = CastConfig(
                 scaling_type=scaling_type_input,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_input = CastConfig(scaling_type=scaling_type_input)
+            cast_config_input = CastConfig(
+                scaling_type=scaling_type_input,
+                scaling_granularity=scaling_granularity,
+            )
         if scaling_type_weight is ScalingType.STATIC:
             cast_config_weight = CastConfig(
                 scaling_type=scaling_type_weight,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
+            cast_config_weight = CastConfig(
+                scaling_type=scaling_type_weight,
+                scaling_granularity=scaling_granularity,
+            )
        if scaling_type_grad_output is ScalingType.STATIC:
             cast_config_grad_output = CastConfig(
                 scaling_type=scaling_type_grad_output,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
+            cast_config_grad_output = CastConfig(
+                scaling_type=scaling_type_grad_output,
+                scaling_granularity=scaling_granularity,
+            )
 
         config = Float8LinearConfig(
             cast_config_input=cast_config_input,
```
