Commit e76db70
add axiswise scaling to Float8Linear (#920)
Summary:

This PR supports scaling of all arguments of all gemms to be axiswise, and ensures that training with axiswise scaling works end to end.

Future PR: support more granular configurability, optimize performance, and add docs. Feel free to ignore the UX introduced in this PR; it is just an intermediate step. See the next PR for the real UX.

Test Plan:

```
// tests pass
./test/float8/test_everything.sh

// sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8:
// 1. verify performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane, performance isn't there though
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
```
1 parent f81fe11 · commit e76db70

9 files changed: +462 −55 lines
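
The knob the diffs below thread through is `scaling_granularity` on `CastConfig`. As a minimal sketch of the intermediate UX this commit exercises (not code from the commit itself; the `convert_to_float8_training(..., config=...)` call shape is an assumption based on its import in `profile_linear_float8.py`), all three gemm arguments can be cast with dynamic, axiswise scaling like this:

```python
# Minimal sketch of the intermediate UX exercised by the diffs below (not code
# from this commit): cast input, weight, and grad_output with dynamic scaling
# at axiswise granularity.
import torch
import torch.nn as nn

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    ScalingGranularity,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

axiswise_dynamic = dict(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.AXISWISE,
)
config = Float8LinearConfig(
    cast_config_input=CastConfig(**axiswise_dynamic),
    cast_config_weight=CastConfig(**axiswise_dynamic),
    cast_config_grad_output=CastConfig(**axiswise_dynamic),
)

# per the test constraints in this PR, axiswise scaling targets bf16 linears
# without bias on SM 9.0 (H100-class) GPUs
m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).to("cuda", torch.bfloat16)
# call signature assumed; swapping modules in place for float8 training
convert_to_float8_training(m, config=config)
```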

benchmarks/float8/bench_linear_float8.py

Lines changed: 27 additions & 5 deletions

```diff
@@ -14,7 +14,12 @@
 
 import torch
 import torch.utils.benchmark as benchmark
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
     linear_requires_sync,
@@ -107,35 +112,49 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
        )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
        )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
@@ -167,7 +186,7 @@ def main(
         copy.deepcopy(linear_ref),
         config=config,
     )
-    scaling_repr = linear_float8.scaling_repr()
+    scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
 
     if fast_accum:
         linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -310,6 +329,7 @@ def invoke_main() -> None:
     parser.add_argument("--scaling_type_input", type=str, required=False)
     parser.add_argument("--scaling_type_weight", type=str, required=False)
     parser.add_argument("--scaling_type_grad_output", type=str, required=False)
+    parser.add_argument("--scaling_granularity", type=str, required=False)
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
     kwargs = {}
@@ -327,6 +347,8 @@ def invoke_main() -> None:
         kwargs["scaling_type_weight"] = args.scaling_type_weight
     if args.scaling_type_grad_output is not None:
         kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
+    if args.scaling_granularity is not None:
+        kwargs["scaling_granularity"] = args.scaling_granularity
     main(
         output_path,
         not args.disable_compile,
```

benchmarks/float8/bench_matmul.py

Lines changed: 11 additions & 2 deletions

```diff
@@ -13,6 +13,8 @@
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
 
+from torchao.float8.config import ScalingGranularity
+
 from utils import (
     get_name_to_shapes_iter,
     profiler_output_to_filtered_time_by_kernel_name,
@@ -75,6 +77,7 @@ def run(
     K: Optional[int] = None,
     N: Optional[int] = None,
     use_gpu_kernel_time: bool = False,
+    scaling_granularity: str = "tensorwise",
 ):
     device = "cuda"
 
@@ -84,6 +87,7 @@ def run(
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     fast_accum_vals = [True, False]
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
@@ -109,8 +113,13 @@ def run(
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
         A = torch.zeros(M, K, device=device, dtype=d1)
         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-        scale_a = torch.tensor([1.0], device=device)
-        scale_b = torch.tensor([1.0], device=device)
+        if scaling_granularity == ScalingGranularity.TENSORWISE:
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        else:
+            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)
 
         def do_matmul(A, B):
             nonlocal scale_a
```
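
In the last hunk above, the benchmark only changes the shapes of `scale_a` and `scale_b`: a single scalar for tensorwise scaling versus `(M, 1)` and `(1, N)` placeholders of ones for axiswise, since it measures gemm throughput rather than numerics. As a hedged illustration of where real axiswise scales would come from (a sketch, not this repo's implementation; the scale convention actually passed to the scaled-matmul kernel may be the reciprocal of what is shown), per-row scales for A and per-column scales for B can be derived from the amax along the reduction dimension:

```python
# Sketch (not this repo's code): derive axiswise float8 scales matching the
# (M, 1) and (1, N) placeholder shapes used in the benchmark above.
import torch

M, K, N = 1024, 2048, 4096
A_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B_hp = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

f8_max = torch.finfo(torch.float8_e4m3fn).max
# amax over the K (reduction) dimension, keeping the scaled axis
amax_a = A_hp.abs().amax(dim=1, keepdim=True).float()  # shape (M, 1)
amax_b = B_hp.abs().amax(dim=0, keepdim=True).float()  # shape (1, N)
scale_a = f8_max / amax_a.clamp(min=1e-12)
scale_b = f8_max / amax_b.clamp(min=1e-12)

# quantize: multiply by the scale, clamp to the float8 range, then cast
A_f8 = (A_hp.float() * scale_a).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
B_f8 = (B_hp.float() * scale_b).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
```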

benchmarks/float8/profile_linear_float8.py

Lines changed: 23 additions & 4 deletions

```diff
@@ -22,7 +22,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -252,6 +257,7 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -263,28 +269,41 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)
 
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
        )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
        )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
        )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
```

test/float8/test_base.py

Lines changed: 30 additions & 3 deletions

```diff
@@ -324,6 +324,10 @@ def _test_linear_impl(
         "scaling_type_grad_output",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
     )
+    @pytest.mark.parametrize(
+        "scaling_granularity",
+        [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
+    )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -334,33 +338,56 @@ def test_linear(
         scaling_type_input: ScalingType,
         scaling_type_weight: ScalingType,
         scaling_type_grad_output: ScalingType,
+        scaling_granularity: ScalingGranularity,
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
+        if scaling_granularity is ScalingGranularity.AXISWISE:
+            if (
+                scaling_type_input != ScalingType.DYNAMIC or
+                scaling_type_weight != ScalingType.DYNAMIC or
+                scaling_type_grad_output != ScalingType.DYNAMIC or
+                linear_dtype != torch.bfloat16 or
+                (not is_cuda_9_0)
+            ):
+                pytest.skip()
+
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
 
         if scaling_type_input is ScalingType.STATIC:
             cast_config_input = CastConfig(
                 scaling_type=scaling_type_input,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_input = CastConfig(scaling_type=scaling_type_input)
+            cast_config_input = CastConfig(
+                scaling_type=scaling_type_input,
+                scaling_granularity=scaling_granularity,
+            )
         if scaling_type_weight is ScalingType.STATIC:
             cast_config_weight = CastConfig(
                 scaling_type=scaling_type_weight,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
+            cast_config_weight = CastConfig(
+                scaling_type=scaling_type_weight,
+                scaling_granularity=scaling_granularity,
+            )
         if scaling_type_grad_output is ScalingType.STATIC:
             cast_config_grad_output = CastConfig(
                 scaling_type=scaling_type_grad_output,
+                scaling_granularity=scaling_granularity,
                 static_scale=torch.tensor([1.0], device="cuda"),
             )
         else:
-            cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
+            cast_config_grad_output = CastConfig(
+                scaling_type=scaling_type_grad_output,
+                scaling_granularity=scaling_granularity,
+            )
 
         config = Float8LinearConfig(
             cast_config_input=cast_config_input,
```
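
The new guard in `test_linear` above restricts axiswise cases to all-dynamic scaling, `torch.bfloat16`, and SM 9.0 hardware (the `is_cuda_9_0` flag). A minimal sketch of that kind of capability check, assuming the flag wraps a compute-capability query (the actual helper in `test_base.py` may be defined differently):

```python
# Sketch of a hardware gate like the test's is_cuda_9_0 flag: axiswise-scaled
# float8 gemms here target compute capability 9.0 (H100-class) or newer.
import torch

is_cuda_9_0 = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (9, 0)
)

if not is_cuda_9_0:
    print("skipping axiswise float8 checks: requires SM 9.0+ (e.g. H100)")
```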
