This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 8352894

vkuzo authored and facebook-github-bot committed
rename all variables to use input/weight/grad_output notation (#335)
Summary:
Pull Request resolved: #335

In #323 we changed the user facing variable notation from `x/w/dL_dY` to `input/weight/grad_output`. This PR follows up by changing most of the internal variables to also match the new notation, to reduce confusion.

Reviewed By: weifengpy

Differential Revision: D60252071

fbshipit-source-id: b91ec5b975df550962418eafc93f1904d64a3dd8
1 parent b9b606e commit 8352894

9 files changed, +100 -94 lines changed

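Before the per-file hunks, a minimal usage sketch of the renamed roles (not part of this commit; the import paths, shapes, dtype, and CUDA device are assumptions based on the files touched below):

import torch

# assumed import paths, taken from the file names in this diff
from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_tensor import GemmInputRole, LinearMMConfig

# LinearMMConfig() falls back to its defaults; its fields are now named
# output / grad_input / grad_weight instead of y / dL_dX / dL_dW
config = LinearMMConfig()

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
w = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)

# old role names: GemmInputRole.X / .W / .DL_DY
# new role names: GemmInputRole.INPUT / .WEIGHT / .GRAD_OUTPUT
x_fp8 = cast_to_float8_e4m3_dynamic(x, config, gemm_input_role=GemmInputRole.INPUT)
w_fp8 = cast_to_float8_e4m3_dynamic(w, config, gemm_input_role=GemmInputRole.WEIGHT)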
float8_experimental/float8_dynamic_utils.py

+2 -2

@@ -42,7 +42,7 @@ def backward(ctx, gradY):
             gradY_scale,
             e5m2_dtype,
             linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.DL_DY,
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
         )
         return fp8_tensor, None
 
@@ -51,7 +51,7 @@ def cast_to_float8_e4m3_dynamic(
     inpt_tensor: torch.Tensor,
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
-    gemm_input_role: GemmInputRole = GemmInputRole.X,
+    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor

float8_experimental/float8_linear.py

+39 -37

@@ -125,7 +125,7 @@ def backward(ctx, go):
             fp8_scale_grad_output,
             e5m2_dtype,
             linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.DL_DY,
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -273,21 +273,21 @@ def convert_amax_buffer_to_float32(self):
             if self._buffers[key] is not None:
                 self._buffers[key] = self._buffers[key].to(torch.float32)
 
-    def cast_x_to_float8(
-        self, x: torch.Tensor, is_amax_initialized: bool
+    def cast_input_to_float8(
+        self, input: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
             autocast_dtype = torch.get_autocast_gpu_dtype()
-            x = x.to(autocast_dtype)
+            input = input.to(autocast_dtype)
 
         if self.scaling_type_input is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
-                x,
+                input,
                 self.fp8_amax_input,
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
@@ -296,29 +296,29 @@ def cast_x_to_float8(
                 is_amax_initialized,
                 reduce_amax=True,
             )
-            x_fp8 = Float8Tensor.to_float8(
-                x,
+            input_fp8 = Float8Tensor.to_float8(
+                input,
                 self.fp8_scale_input,
                 e4m3_dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
+                gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
-        return x_fp8
+            input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
+        return input_fp8
 
-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
+    def cast_weight_to_float8(
+        self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         if self.scaling_type_weight is TensorScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
                 scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
                 _maybe_initialize_amaxes_scales_for_float8_cast(
-                    w,
+                    weight,
                     self.fp8_amax_weight,
                     self.fp8_amax_history_weight,
                     self.fp8_scale_weight,
@@ -328,29 +328,31 @@ def cast_w_to_float8(
                     reduce_amax=False,
                 )
 
-                w_fp8 = Float8Tensor.to_float8(
-                    w,
+                weight_fp8 = Float8Tensor.to_float8(
+                    weight,
                     self.fp8_scale_weight,
                     e4m3_dtype,
                     self.fp8_amax_weight,
                     linear_mm_config=self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.W,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
         else:
             assert self.scaling_type_weight is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(
-                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                weight_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight,
+                    self.linear_mm_config,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
-        return w_fp8
+        return weight_fp8
 
-    def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
+    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            y = NoopFwToFloat8E5M2Bw.apply(
-                y,
+            output = NoopFwToFloat8E5M2Bw.apply(
+                output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
@@ -360,10 +362,10 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
             )
         else:
             assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
-        return y
+            output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
+        return output
 
-    def float8_pre_forward(self, x):
+    def float8_pre_forward(self, input):
         if not self.enable_pre_and_post_forward:
             return
         if (
@@ -374,7 +376,7 @@ def float8_pre_forward(self, x):
             raise AssertionError(
                 "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
             )
-        self.last_seen_input_dtype = x.dtype
+        self.last_seen_input_dtype = input.dtype
 
     def float8_post_forward(self):
         if not self.enable_pre_and_post_forward:
@@ -388,25 +390,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_pre_forward(input)
 
-        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
+        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        output = torch.matmul(input_fp8, weight_fp8.t())
 
-        # Cast gradY to float8_e5m2 during backward
-        y = self.cast_y_to_float8_in_bw(y)
+        # Cast grad_output to float8_e5m2 during backward
+        output = self.cast_output_to_float8_in_bw(output)
 
         if self.bias is not None:
-            y = y + self.bias.to(y.dtype)
+            output = output + self.bias.to(output.dtype)
 
         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        return y
+        return output
 
     def scaling_repr(self):
         # add scaling settings without using too many characters
-        # example: "x:del,w:del,dldy:dyn"
-        return f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}"
+        # example: "i:del,w:del,go:dyn"
+        return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}"
 
     def extra_repr(self):
         s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'

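Condensed for reference, the Float8Linear forward path after this rename, restated from the hunks above (delayed-scaling bookkeeping and the surrounding class are elided; this is an excerpt, not standalone code):

def forward(self, input: torch.Tensor) -> torch.Tensor:
    # cast both forward operands to float8 (e4m3), tagging their gemm roles
    input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
    weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

    # gemm 1: input @ weight_t = output
    output = torch.matmul(input_fp8, weight_fp8.t())

    # no-op in the forward; casts grad_output to float8 (e5m2) in the backward
    output = self.cast_output_to_float8_in_bw(output)

    if self.bias is not None:
        output = output + self.bias.to(output.dtype)
    return output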
float8_experimental/float8_tensor.py

+28 -27

@@ -27,21 +27,21 @@
 #
 # There are three gemms in a forward + backward of a Linear layer:
 #
-# 1. x @ w_t = y (forward pass)
-# 2. dL_dY @ w = dL_dX (backward pass)
-# 3. x_t @ dL_dY = dL_dW (backward pass)
+# 1. input @ weight_t = output (forward pass)
+# 2. grad_output @ weight = grad_input (backward pass)
+# 3. input_t @ grad_output = grad_weight (backward pass)
 #
 # In the formulas above, there are:
-# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t).
-#    - Note that dL_dY_t is implied because of memory format requirements
+# A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t).
+#    - Note that grad_output_t is implied because of memory format requirements
 #      of float8 gemms
-# B. three output tensors (y, dL_dX, dL_dW)
+# B. three output tensors (output, grad_input, grad_weight)
 #
 # We want each input tensor, gemm, and output tensor to be configurable.
 # The state of this configuration today is:
 #
 # i. pairs of input tensors (non-t and t variants) have their scaling
-#    configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear
+#    configurable via the scaling_type_* arguments to Float8Linear
 # ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing
 # iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed
 #      to configure all three gemms, also not user facing
@@ -60,11 +60,12 @@
 
 # The object below is not user facing and exists for convenience,
 # to allow Float8Tensor to use
-# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is
+# the right config based on which gemm from gemms with outputs
+# `output`, `grad_input`, `grad_weight` is
 # being called.
 LinearMMConfig = namedtuple(
     "LinearMMConfig",
-    ["y", "dL_dX", "dL_dW"],
+    ["output", "grad_input", "grad_weight"],
     defaults=[
         ScaledMMConfig(False, True, False, False),
         ScaledMMConfig(False, False, False, False),
@@ -81,9 +82,9 @@ class GemmInputRole(enum.Enum):
     gemm is performed.
     """
 
-    X = "x"
-    W = "w"
-    DL_DY = "dL_dY"
+    INPUT = "input"
+    WEIGHT = "weight"
+    GRAD_OUTPUT = "grad_output"
 
 
 # choose which scaled_mm_config to use based on gemm inputs
@@ -93,21 +94,21 @@ def choose_scaled_mm_config(
     b_role: GemmInputRole,
     b_linear_mm_config: LinearMMConfig,
 ):
-    if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
+    if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT:
         assert (
-            a_linear_mm_config.y == b_linear_mm_config.y
-        ), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}"
-        return a_linear_mm_config.y
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
+            a_linear_mm_config.output == b_linear_mm_config.output
+        ), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}"
+        return a_linear_mm_config.output
+    elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT:
         assert (
-            a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
-        ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
-        return a_linear_mm_config.dL_dX
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
+            a_linear_mm_config.grad_input == b_linear_mm_config.grad_input
+        ), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}"
+        return a_linear_mm_config.grad_input
+    elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT:
         assert (
-            a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
-        ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
-        return a_linear_mm_config.dL_dW
+            a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight
+        ), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}"
+        return a_linear_mm_config.grad_weight
     else:
         raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")
 
@@ -207,7 +208,7 @@ def forward(
         float8_dtype=e4m3_dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         linear_mm_config: Optional[LinearMMConfig] = None,
-        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
         """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
         Args
@@ -287,7 +288,7 @@ def __new__(
         scale: torch.Tensor,
         orig_dtype: torch.dtype,
         linear_mm_config: Optional[LinearMMConfig],
-        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
         assert (
             scale.numel() == 1
@@ -348,7 +349,7 @@ def to_float8(
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
        linear_mm_config: Optional[LinearMMConfig] = None,
-        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
        """Converts a higher precision tensor to float8 in a differentiable way.
 

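The comment block in the float8_tensor.py hunk above names the three gemms behind the new notation; choose_scaled_mm_config picks the output / grad_input / grad_weight config from exactly those role pairings. A plain-torch sketch of the three gemms, without any float8 casting (shapes are assumptions, not from this commit):

import torch

input = torch.randn(16, 32, requires_grad=True)   # (batch, in_features)
weight = torch.randn(64, 32, requires_grad=True)  # (out_features, in_features)

# gemm 1 (forward):  input @ weight_t = output
output = input @ weight.t()

grad_output = torch.randn_like(output)
# gemm 2 (backward): grad_output @ weight = grad_input
grad_input = grad_output @ weight
# gemm 3 (backward): input_t @ grad_output = grad_weight
# (written transposed here so the result matches weight's (out, in) layout)
grad_weight = grad_output.t() @ input

# autograd produces the same quantities
output.backward(grad_output)
torch.testing.assert_close(grad_input, input.grad)
torch.testing.assert_close(grad_weight, weight.grad)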
float8_experimental/float8_tensor_parallel.py

+3 -3

@@ -48,7 +48,7 @@ def _prepare_input_fn(
         input_tensor = cast_to_float8_e4m3_dynamic(
             input_tensor,
             mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.X,
+            gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
 
         # transform the input layouts to the desired layouts of ColwiseParallel
@@ -101,7 +101,7 @@ def _prepare_input_fn(
         input_tensor = cast_to_float8_e4m3_dynamic(
             input_tensor,
             mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.X,
+            gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
 
         if input_layouts != desired_input_layouts:
@@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
             dt_inp = cast_to_float8_e4m3_dynamic(
                 dt_inp,
                 self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
+                gemm_input_role=GemmInputRole.INPUT,
             )  # DTensor(Float8Tensor)
             if desired_layout is not None and input_layout != desired_layout:
                 dt_inp = dt_inp.redistribute(placements=(desired_layout,))

float8_experimental/fsdp_utils.py

+5 -5

@@ -169,14 +169,14 @@ def fsdp_pre_all_gather(self, mesh):
                 self._precomputed_scale,
                 torch.float8_e4m3fn,
                 linear_mm_config=self._linear_mm_config,
-                gemm_input_role=GemmInputRole.W,
+                gemm_input_role=GemmInputRole.WEIGHT,
             )
         else:
             float8_tensor = cast_to_float8_e4m3_dynamic(
                 self._tensor,
                 self._linear_mm_config,
                 reduce_amax=True,
-                gemm_input_role=GemmInputRole.W,
+                gemm_input_role=GemmInputRole.WEIGHT,
             )
         return (float8_tensor._data,), (float8_tensor._scale,)
 
@@ -199,7 +199,7 @@ def fsdp_post_all_gather(
             scale,
             param_dtype,
             self._linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         ), (data,)
 
 
@@ -362,7 +362,7 @@ def fsdp_pre_all_gather(self, mesh):
             e4m3_dtype,
             self._amax_buffer,
             self._linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         )
         return (float8_tensor._data,), (float8_tensor._scale,)
 
@@ -385,5 +385,5 @@ def fsdp_post_all_gather(
             scale,
             param_dtype,
             self._linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         ), (data,)

float8_experimental/inference.py

+2 -2

@@ -132,7 +132,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
             scale,
             dtype,
             self.linear_mm_config,
-            gemm_input_role=GemmInputRole.W,
+            gemm_input_role=GemmInputRole.WEIGHT,
         )
         self.weight = nn.Parameter(quantized_weight)
         self.weight.requires_grad = False
@@ -205,7 +205,7 @@ def cast_to_float8_e4m3_inference(
         scale,
         e4m3_dtype,
         linear_mm_config=linear_mm_config,
-        gemm_input_role=GemmInputRole.X,
+        gemm_input_role=GemmInputRole.INPUT,
     )
 
