
Commit d759f81

add axiswise scaling to Float8Linear
Summary:

This PR supports scaling of all arguments of all gemms to be axiswise, and ensures that training with axiswise scaling works e2e.

Future PR: support more granular configurability, optimize performance, and add docs.

Test Plan:
```
// tests pass
./test/float8/test_everything.sh

// sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8:
// 1. verify performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane; performance isn't there yet
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 304a542
ghstack-comment-id: 2368837904
Pull Request resolved: #920
1 parent 5711a01 commit d759f81

9 files changed (+450, -41 lines)
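For orientation, here is a minimal sketch of how the axiswise path added in this commit could be exercised end to end. It is an assumption-laden example, not code from this diff: the `convert_to_float8_training(model, config=...)` call signature and the toy model are assumptions, while `CastConfig`, `Float8LinearConfig`, `ScalingType`, and `ScalingGranularity` come from the diffs below.

```python
import torch
import torch.nn as nn

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    ScalingGranularity,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

# dynamic, axiswise scaling for all three gemm arguments (input, weight, grad_output)
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.AXISWISE,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)

# axiswise scaling currently targets bf16 linears on H100-class GPUs (see the test skips below)
m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).cuda().to(torch.bfloat16)
convert_to_float8_training(m, config=config)  # assumed signature; swaps nn.Linear -> Float8Linear

x = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
```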

benchmarks/float8/bench_linear_float8.py

Lines changed: 27 additions & 5 deletions
@@ -14,7 +14,12 @@
 
 import torch
 import torch.utils.benchmark as benchmark
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
     linear_requires_sync,
@@ -107,35 +112,49 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
@@ -167,7 +186,7 @@ def main(
         copy.deepcopy(linear_ref),
         config=config,
     )
-    scaling_repr = linear_float8.scaling_repr()
+    scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
 
     if fast_accum:
         linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -310,6 +329,7 @@ def invoke_main() -> None:
     parser.add_argument("--scaling_type_input", type=str, required=False)
     parser.add_argument("--scaling_type_weight", type=str, required=False)
     parser.add_argument("--scaling_type_grad_output", type=str, required=False)
+    parser.add_argument("--scaling_granularity", type=str, required=False)
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
     kwargs = {}
@@ -327,6 +347,8 @@ def invoke_main() -> None:
     kwargs["scaling_type_weight"] = args.scaling_type_weight
     if args.scaling_type_grad_output is not None:
         kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
+    if args.scaling_granularity is not None:
+        kwargs["scaling_granularity"] = args.scaling_granularity
     main(
         output_path,
         not args.disable_compile,
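The new `scaling_type_repr()` / `scaling_granularity_repr()` helpers used above can be combined the same way the benchmark does. A small hedged sketch follows; it assumes `Float8Linear.from_float(mod, config=config)` is the call being made in the truncated hunk above (the callee name is not visible in this diff), and the layer shape is arbitrary.

```python
import copy

import torch
import torch.nn as nn

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    ScalingGranularity,
)
from torchao.float8.float8_linear import Float8Linear

cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.AXISWISE,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)

linear_ref = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
# assumed constructor, matching the `config=config` argument visible in the hunk above
linear_float8 = Float8Linear.from_float(copy.deepcopy(linear_ref), config=config)

# the combined string is what ends up in the benchmark's output labels
scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
print(scaling_repr)
```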

benchmarks/float8/bench_matmul.py

Lines changed: 11 additions & 2 deletions
@@ -13,6 +13,8 @@
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
 
+from torchao.float8.config import ScalingGranularity
+
 from utils import (
     get_name_to_shapes_iter,
     profiler_output_to_filtered_time_by_kernel_name,
@@ -75,6 +77,7 @@ def run(
     K: Optional[int] = None,
     N: Optional[int] = None,
     use_gpu_kernel_time: bool = False,
+    scaling_granularity: str = "tensorwise",
 ):
     device = "cuda"
 
@@ -84,6 +87,7 @@ def run(
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     fast_accum_vals = [True, False]
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
@@ -109,8 +113,13 @@ def run(
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
         A = torch.zeros(M, K, device=device, dtype=d1)
         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-        scale_a = torch.tensor([1.0], device=device)
-        scale_b = torch.tensor([1.0], device=device)
+        if scaling_granularity == ScalingGranularity.TENSORWISE:
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        else:
+            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)
 
         def do_matmul(A, B):
             nonlocal scale_a
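The shape change above is the crux of axiswise scaling at the gemm level: instead of one scalar scale per tensor, there is one scale per row of A and one per column of B. Below is a minimal sketch of the scaled matmul this benchmark times; it assumes an H100-class GPU and a recent PyTorch where the private `torch._scaled_mm` accepts float32 scales of shape `(M, 1)` and `(1, N)` with a bf16 output, which is not guaranteed by this diff.

```python
import torch

device = "cuda"
M, K, N = 1024, 2048, 4096

# float8 operands laid out the way the benchmark creates them
A = torch.zeros(M, K, device=device, dtype=torch.float8_e4m3fn)
B = torch.zeros(K, N, device=device, dtype=torch.float8_e4m3fn).t().contiguous().t()

# axiswise scales: one per row of A, one per column of B
scale_a = torch.ones(M, 1, device=device, dtype=torch.float32)
scale_b = torch.ones(1, N, device=device, dtype=torch.float32)

# assumed torch._scaled_mm signature (private API, PyTorch 2.4+ style)
out = torch._scaled_mm(
    A,
    B,
    scale_a=scale_a,
    scale_b=scale_b,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
print(out.shape)  # torch.Size([1024, 4096])
```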

benchmarks/float8/profile_linear_float8.py

Lines changed: 23 additions & 4 deletions
@@ -22,7 +22,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -252,6 +257,7 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -263,28 +269,41 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,

test/float8/test_base.py

Lines changed: 30 additions & 3 deletions
@@ -330,6 +330,10 @@ def _test_linear_impl(
         "scaling_type_grad_output",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
     )
+    @pytest.mark.parametrize(
+        "scaling_granularity",
+        [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
+    )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -340,6 +344,7 @@ def test_linear(
         scaling_type_input: ScalingType,
         scaling_type_weight: ScalingType,
         scaling_type_grad_output: ScalingType,
+        scaling_granularity: ScalingGranularity,
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
@@ -352,30 +357,52 @@ def test_linear(
                 f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
             )
             pytest.skip()
+        if scaling_granularity is ScalingGranularity.AXISWISE:
+            if (
+                scaling_type_input != ScalingType.DYNAMIC or
+                scaling_type_weight != ScalingType.DYNAMIC or
+                scaling_type_grad_output != ScalingType.DYNAMIC or
+                linear_dtype != torch.bfloat16 or
+                (not is_cuda_9_0)
+            ):
+                pytest.skip()
+
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
 
         if scaling_type_input is ScalingType.STATIC:
             cast_config_input = CastConfig(
                 scaling_type=scaling_type_input,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_input = CastConfig(scaling_type=scaling_type_input)
+            cast_config_input = CastConfig(
+                scaling_type=scaling_type_input,
+                scaling_granularity=scaling_granularity,
+            )
         if scaling_type_weight is ScalingType.STATIC:
             cast_config_weight = CastConfig(
                 scaling_type=scaling_type_weight,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
+            cast_config_weight = CastConfig(
+                scaling_type=scaling_type_weight,
+                scaling_granularity=scaling_granularity,
+            )
        if scaling_type_grad_output is ScalingType.STATIC:
             cast_config_grad_output = CastConfig(
                 scaling_type=scaling_type_grad_output,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
+            cast_config_grad_output = CastConfig(
+                scaling_type=scaling_type_grad_output,
+                scaling_granularity=scaling_granularity,
+            )
 
         config = Float8LinearConfig(
             cast_config_input=cast_config_input,
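To run just this parametrized test locally (outside of `test_everything.sh`), one hedged option is to drive pytest from Python; the `-k test_linear` filter is an assumption about how one might narrow the run, not something prescribed by this PR.

```python
import sys

import pytest

# select only the test_linear parametrizations (tensorwise and axiswise) in this file
sys.exit(pytest.main(["test/float8/test_base.py", "-k", "test_linear", "-q"]))
```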
