This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 6fd2a08

Thread through the scaling type argument to float8 constructors
ghstack-source-id: 6f9b929
Pull Request resolved: #301
1 parent d4cf2ad · commit 6fd2a08

10 files changed: +320 -107 lines
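
In short, the dynamic cast helpers and the float8 constructors now take an explicit ScalingGranularity, and the modules hardcode ScalingGranularity.TensorWise for now (see the TODOs in the diffs below). A minimal usage sketch, not part of this commit; it assumes float8_experimental at this revision is importable and that ScaledMMConfig() constructs a default config:

import torch

from float8_experimental.float8_dynamic_linear import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_tensor import ScaledMMConfig, ScalingGranularity

# The scaling granularity is now a required positional argument of the dynamic
# cast helper; only TensorWise scaling is wired up by this commit.
x = torch.randn(16, 32)
x_fp8 = cast_to_float8_e4m3_dynamic(x, ScaledMMConfig(), ScalingGranularity.TensorWise)
print(x_fp8._scale)  # TensorWise scaling keeps a single scale for the whole tensor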

.pre-commit-config.yaml

Lines changed: 0 additions & 2 deletions
@@ -10,8 +10,6 @@ repos:
       - id: trailing-whitespace
       - id: check-ast
       - id: check-merge-conflict
-      - id: no-commit-to-branch
-        args: ['--branch=main']
       - id: check-added-large-files
         args: ['--maxkb=500']
       - id: end-of-file-fixer

float8_experimental/float8_dynamic_linear.py

Lines changed: 83 additions & 24 deletions
@@ -19,6 +19,7 @@
     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
+    ScalingGranularity,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -36,21 +37,26 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        tensor,
+        tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
     ):
         ctx.mm_config = mm_config
+        ctx.scaling_granularity = scaling_granularity
         return tensor

     @staticmethod
-    def backward(ctx, gradY):
+    def backward(ctx, gradY: torch.Tensor):
         if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
-        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
+            return gradY, None, None
+        gradY_scale = tensor_to_scale(gradY, e5m2_dtype, ctx.scaling_granularity)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
         )
-        return fp8_tensor, None
+        return fp8_tensor, None, None


 class Float8DynamicLinear(torch.nn.Linear):
@@ -63,13 +69,19 @@ def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3_dynamic(
+            input, self.forward_config, self.scaling_granularity
+        )
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+            w_fp8 = cast_to_float8_e4m3_dynamic(
+                self.weight, self.forward_config, self.scaling_granularity
+            )
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+        y = cast_to_float8_e5m2_dynamic_bw(
+            y, self.backward_config, self.scaling_granularity
+        )
         return y

     @classmethod
@@ -101,9 +113,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             fp8_output=False,
             pad_inner_dim=config.pad_inner_dim,
         )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_granularity = ScalingGranularity.TensorWise
+
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
+                WeightWithDynamicFloat8CastTensor(
+                    mod.weight, new_mod.forward_config, new_mod.scaling_granularity
+                )
             )
         else:
             new_mod.weight = mod.weight
@@ -112,18 +129,31 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


 def cast_to_float8_e4m3_dynamic(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_granularity: ScalingGranularity,
+    reduce_amax: bool = False,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    scale = tensor_to_scale(
+        inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+    )
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
+    )


 def cast_to_float8_e5m2_dynamic_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_granularity: ScalingGranularity,
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_granularity)


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +173,12 @@ def cast_to_float8_e5m2_dynamic_bw(

 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -157,24 +192,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )

-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        self._scaling_granularity = scaling_granularity

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.detach.default:
             return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
+                args[0]._tensor, args[0]._mm_config, args[0]._scaling_granularity
             )
         mm_config: Optional[ScaledMMConfig] = None
+        scaling_granularity: Optional[ScalingGranularity] = None

         def unwrap(t):
             nonlocal mm_config
+            nonlocal scaling_granularity
             if mm_config is None:
                 mm_config = t._mm_config
             else:
                 mm_config = merge_mm_configs(mm_config, t._mm_config)
+
+            if scaling_granularity is None:
+                scaling_granularity = t._scaling_granularity
+            else:
+                # TODO For now we assume that the scaling granularity is same across all tensors
+                assert scaling_granularity == t._scaling_granularity
             return t._tensor

         args, kwargs = pytree.tree_map_only(
@@ -184,23 +233,33 @@ def unwrap(t):
         if func not in _ops_to_preserve_subclass:
             return out
         return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+            torch.Tensor,
+            lambda x: WeightWithDynamicFloat8CastTensor(
+                x, mm_config, scaling_granularity
+            ),
+            out,
         )

     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        return ["_tensor"], {
+            "_mm_config": self._mm_config,
+            "_scaling_granularity": self._scaling_granularity,
+        }

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        mm_config = flatten_spec["_mm_config"]
+        scaling_granularity = flatten_spec["_scaling_granularity"]
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"], mm_config, scaling_granularity
+        )

     def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_granularity={self._scaling_granularity})"

     def fsdp_pre_all_gather(self, mesh):
         float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
+            self._tensor, self._mm_config, self._scaling_granularity, reduce_amax=True
         )
         return (float8_tensor._data,), (float8_tensor._scale,)
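
As a side note, __tensor_flatten__ now returns a dict so that both the mm_config and the scaling granularity survive a flatten/unflatten round trip. A hypothetical sketch of that round trip (names taken from the diff above; assumes ScaledMMConfig() builds a default config):

import torch

from float8_experimental.float8_dynamic_linear import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_tensor import ScaledMMConfig, ScalingGranularity

w = WeightWithDynamicFloat8CastTensor(
    torch.randn(64, 32), ScaledMMConfig(), ScalingGranularity.TensorWise
)
inner_names, spec = w.__tensor_flatten__()
assert inner_names == ["_tensor"]
# The flatten spec now carries both pieces of metadata.
assert spec["_scaling_granularity"] is ScalingGranularity.TensorWise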

float8_experimental/float8_linear.py

Lines changed: 31 additions & 8 deletions
@@ -25,6 +25,7 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     ScaledMMConfig,
+    ScalingGranularity,
     to_fp8_no_autograd,
 )

@@ -45,6 +46,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
     float8_dtype,
     is_initialized,
     reduce_amax,
+    scaling_granularity: ScalingGranularity,
 ):
     """
     If x is about to be cast to `float8` and the amax buffers are not initialized,
@@ -56,7 +58,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         # Note: we need to enable distributed reduction here in order
         # to match numerics between single GPU and multi GPU code for
         # activations and gradients
-        new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+        new_amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax)
         cur_amax.fill_(new_amax)
         amax_history[0] = new_amax
         new_scale = amax_history_to_scale(
@@ -82,11 +84,13 @@ def forward(
         scale_fn_name,
         is_amax_initialized,
         mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
         ctx.mm_config = mm_config
+        ctx.scaling_granularity = scaling_granularity
         return tensor

     @staticmethod
@@ -104,14 +108,18 @@ def backward(ctx, go):
             e5m2_dtype,
             is_amax_initialized,
             reduce_amax=True,
+            scaling_granularity=ctx.scaling_granularity,
         )

-        fp8_amax_dL_dY.fill_(tensor_to_amax(go))
+        fp8_amax_dL_dY.fill_(tensor_to_amax(go, ctx.scaling_granularity))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
         )
-        empty_grads = None, None, None, None, None, None
+        empty_grads = None, None, None, None, None, None, None
         return res, *empty_grads


@@ -196,6 +204,10 @@ def __init__(self, *args, **kwargs):
             emulate, False, False, config.pad_inner_dim
         )

+        # Defines the scaling granularity for the forward and backwards pass
+        # TODO: For now hardcode TensorWise scaling
+        self.scaling_granularity = ScalingGranularity.TensorWise
+
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
         # TODO(future PR): add serialization for this flag
@@ -298,6 +310,7 @@ def cast_x_to_float8(
                 e4m3_dtype,
                 is_amax_initialized,
                 reduce_amax=True,
+                scaling_granularity=self.scaling_granularity,
             )
             x_fp8 = Float8Tensor.to_float8(
                 x,
@@ -308,7 +321,9 @@ def cast_x_to_float8(
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(
+                x, self.forward_config, self.scaling_granularity
+            )
         return x_fp8

     def cast_w_to_float8(
@@ -325,6 +340,7 @@ def cast_w_to_float8(
                 e4m3_dtype,
                 is_amax_initialized,
                 reduce_amax=False,
+                scaling_granularity=self.scaling_granularity,
             )

             w_fp8 = Float8Tensor.to_float8(
@@ -340,7 +356,9 @@ def cast_w_to_float8(
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.forward_config, self.scaling_granularity
+                )
         return w_fp8

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -354,10 +372,13 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.backward_config,
+                self.scaling_granularity,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(
+                y, self.backward_config, self.scaling_granularity
+            )
         return y

     def float8_pre_forward(self, x):
@@ -440,7 +461,9 @@ def from_float(
             and config.enable_fsdp_fp8_all_gather
         ):
             new_mod.weight = torch.nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
+                WeightWithDynamicFloat8CastTensor(
+                    mod.weight, new_mod.forward_config, new_mod.scaling_granularity
+                )
             )
         else:
             assert not config.enable_fsdp_fp8_all_gather, "unsupported"
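
The delayed-scaling path threads the same argument through the amax/scale helpers. A small sketch of the updated calls (signatures as used in this diff; the float8_experimental.float8_utils import path is assumed from the repo layout):

import torch

from float8_experimental.float8_tensor import ScalingGranularity
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_amax, tensor_to_scale

t = torch.randn(8, 8)
# Granularity is now passed positionally after the tensor (and after the dtype, for scales).
amax = tensor_to_amax(t, ScalingGranularity.TensorWise)
scale = tensor_to_scale(t, e4m3_dtype, ScalingGranularity.TensorWise)
print(amax, scale)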
