This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4fb0ada

vkuzo authored and facebook-github-bot committed
make single linear profiling script work with Float8 scaling type (#299)
Summary:
Pull Request resolved: #299

Makes `benchmarks/bench_linear_float8.py` support per-tensor scaling configurations.

Verified that performance is as expected.

Reviewed By: drisspg

Differential Revision: D59305789

fbshipit-source-id: a55df5b52a854cae6c9fcb01a6c1cad5bb1df340
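For illustration, a benchmark run with the new flags might look roughly like the sketch below. This is an assumption-laden example, not part of the commit: it assumes the script's `__main__` entry point dispatches to `invoke_main()`, and that `TensorScalingType` accepts the strings "delayed" and "dynamic" (only "delayed" appears explicitly in this diff).

# sketch only: entry point and the "dynamic" value are assumptions
python benchmarks/bench_linear_float8.py \
    --linear_type_filter delayed \
    --scaling_type_x delayed \
    --scaling_type_w delayed \
    --scaling_type_dL_dY dynamic

Any scaling flag left unspecified keeps its "delayed" default, since `invoke_main()` only forwards the kwargs that were actually passed on the command line.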
1 parent 7a1bdab · commit 4fb0ada

2 files changed, 52 insertions(+), 10 deletions(-)

benchmarks/bench_linear_float8.py

Lines changed: 47 additions & 4 deletions
@@ -14,8 +14,10 @@

 import torch
 import torch.utils.benchmark as benchmark
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
+    linear_requires_sync,
     LinearType,
     sync_float8_amax_and_scale_history,
 )
@@ -68,6 +70,7 @@ class Experiment:
     compiled: bool
     use_fast_accum: bool
     linear_type: str
+    scaling_repr: str

     # 3 Times since we are calculating forward backward
     @property
@@ -96,10 +99,17 @@ def main(
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
     linear_type_filter: Optional[str] = None,
+    scaling_type_x: str = "delayed",
+    scaling_type_w: str = "delayed",
+    scaling_type_dL_dY: str = "delayed",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")

+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+
     # LLaMa 2 70B single-node weight shapes
     # assumes fused attn.wqkv and ffn.w13
     name_to_shapes_70b = {
@@ -134,9 +144,24 @@ def main(
             LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
         )

-        linear_float8 = get_float8_linear(
-            linear_type_enum, copy.deepcopy(linear_ref), emulate=False
-        )
+        if linear_type == "delayed":
+            linear_float8 = get_float8_linear(
+                linear_type_enum,
+                copy.deepcopy(linear_ref),
+                emulate=False,
+                scaling_type_x=scaling_type_x,
+                scaling_type_w=scaling_type_w,
+                scaling_type_dL_dY=scaling_type_dL_dY,
+            )
+            scaling_repr = linear_float8.scaling_repr()
+        else:
+            linear_float8 = get_float8_linear(
+                linear_type_enum,
+                copy.deepcopy(linear_ref),
+                emulate=False,
+            )
+            scaling_repr = None
+
         if fast_accum:
             linear_float8.forward_config = ScaledMMConfig(False, True, False)
         else:
@@ -150,7 +175,10 @@ def main(
         if linear_type_enum == LinearType.DELAYED:

             def float8_forw_backward():
-                sync_float8_amax_and_scale_history(linear_float8)
+                if linear_requires_sync(
+                    linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+                ):
+                    sync_float8_amax_and_scale_history(linear_float8)
                 linear_float8(input_tensor).sum().backward()

         else:
@@ -197,6 +225,7 @@ def wrapper(*args, **kwargs):
            compile,
            use_fast_accum=fast_accum,
            linear_type=linear_type,
+           scaling_repr=scaling_repr,
        )
        print(experiment)
        print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -209,6 +238,7 @@ def wrapper(*args, **kwargs):
        "K",
        "N",
        "linear_type",
+       "scaling_repr",
        "ref_dtype",
        "compiled",
        "use_fast_accum",
@@ -228,6 +258,7 @@ def wrapper(*args, **kwargs):
                experiment.shape[1],
                experiment.shape[2],
                experiment.linear_type,
+               experiment.scaling_repr,
                experiment.dtype,
                experiment.compiled,
                experiment.use_fast_accum,
@@ -257,6 +288,7 @@ def wrapper(*args, **kwargs):
            "name",
            "shape",
            "linear_type",
+           "scaling_repr",
            "compiled",
            "use_fast_accum",
            "ref_time_sec",
@@ -280,15 +312,26 @@ def invoke_main() -> None:
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
     parser.add_argument("--linear_type_filter", type=str, required=False)
+    parser.add_argument("--scaling_type_x", type=str, required=False)
+    parser.add_argument("--scaling_type_w", type=str, required=False)
+    parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
+    kwargs = {}
+    if args.scaling_type_x is not None:
+        kwargs["scaling_type_x"] = args.scaling_type_x
+    if args.scaling_type_w is not None:
+        kwargs["scaling_type_w"] = args.scaling_type_w
+    if args.scaling_type_dL_dY is not None:
+        kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
     main(
         output_path,
         args.compile,
         args.n_limit,
         args.fast_accum_filter,
         args.shape_name_filter,
         args.linear_type_filter,
+        **kwargs,
     )

float8_experimental/float8_linear.py

Lines changed: 5 additions & 6 deletions
@@ -400,14 +400,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         self.float8_post_forward()
         return y

-    def extra_repr(self):
-        # example: in_features=32, out_features=16, bias=True
-        s = super().extra_repr()
+    def scaling_repr(self):
         # add scaling settings without using too many characters
-        scaling = f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
+        # example: "x:del,w:del,dldy:dyn"
+        return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"

-        s = f'{s}, scaling="{scaling}"'
-        # example: in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
+    def extra_repr(self):
+        s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
         return s

     @classmethod
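After this refactor, `scaling_repr()` returns only the scaling summary, and `extra_repr()` composes it with the base linear layer's fields. Reusing the example values from the removed comment, a module with delayed scaling for x and w and dynamic scaling for dL_dY would report something like:

# illustrative output, built from the examples in the diff comments
in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"

This lets the benchmark script reuse the same compact string for its `scaling_repr` column without duplicating the formatting logic.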
