Commit 988c5c9

fix tensor parallelism for float8 training with rowwise scaling (#1718)
Summary:

1. add a test for toy model + TP + float8 rowwise scaling training
2. fix underlying issues to make the test pass:
   a. add a fast path for tensor view where the new shape is the same as the old shape, for rowwise scaled float8 (this is needed for DTensor)
   b. modify the fake grad dependency workaround to work when grad is a DTensor

Test Plan:

1. ./test/float8/test_everything.sh (one transient failure: https://www.internalfb.com/phabricator/paste/view/P1733103301)
2. verified that float8 rowwise scaling behaves sanely in torchtitan on LLaMa 3 8B on 8 H100s, with tp 2:

```
// requires pytorch/torchtitan#808

// baseline - bfloat16 + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:41:16,175 - root - INFO - step: 40 loss: 7.4240 memory: 35.56GiB(37.43%) tps: 1,669 mfu: 9.77%

// float8 baseline - float8 tensorwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:44:07,806 - root - INFO - step: 40 loss: 7.4993 memory: 35.57GiB(37.44%) tps: 2,141 mfu: 12.54%

// float8 rowwise without zero fake dep (for sanity) + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:47:51,400 - root - INFO - step: 40 loss: 7.3472 memory: 35.55GiB(37.42%) tps: 1,858 mfu: 10.88%

// float8 rowwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:51:20,864 - root - INFO - step: 40 loss: 9.4211 memory: 35.55GiB(37.42%) tps: 1,820 mfu: 10.66%
```

Reviewers:

Subscribers:

Tasks:

Tags:
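For readers who want to try the torchao pieces outside of torchtitan, here is a minimal sketch (not part of this commit) of composing rowwise-scaled float8 training with tensor parallelism, using only APIs that appear in the diffs below. The toy `nn.Sequential`, its plan keys (`"0"`, `"2"`), and the torchrun-style launch are illustrative assumptions.

```python
# Hedged sketch: rowwise (axiswise) float8 training + TP on a toy model.
# Assumes launch via torchrun on CUDA GPUs and torchao at or after this commit.
import os

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
from torchao.float8.float8_linear_utils import convert_to_float8_training

# one 1-D mesh over all ranks (torchrun sets WORLD_SIZE)
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256)).cuda()

# rowwise scaling recipe, as exercised by the new test
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
model = convert_to_float8_training(model, config=config)

# for rowwise scaling the new test keeps the all-gather in high precision,
# so the plain TP styles are used instead of the Float8* ones
model = parallelize_module(
    model, mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()}
)
```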
1 parent 7b37eb0 commit 988c5c9

4 files changed, +113 -40 lines changed

test/float8/test_dtensor.py

Lines changed: 66 additions & 25 deletions
```diff
@@ -23,15 +23,26 @@

 from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import parallelize_module
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    parallelize_module,
+)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
 )
 from tqdm import tqdm

 from torchao.float8 import Float8LinearConfig
-from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearRecipeName,
+    ScalingType,
+    e4m3_dtype,
+    recipe_name_to_linear_config,
+)
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
@@ -49,6 +60,8 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.dtensor_utils import ToyModel

+torch.set_float32_matmul_precision("high")
+

 def setup_distributed():
     world_size = int(os.environ.get("WORLD_SIZE", -1))
@@ -180,13 +193,17 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):


 def _test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False
+    mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
 ):
     device = mesh.device_type
-    # For now, only supports dynamic scaling of `x` and `dL_dY`.
-    # TODO(future): add support for float8 all-gather with delayed scaling
-    # for activations and gradients.
-    config = Float8LinearConfig(emulate=True)
+
+    if rowwise:
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        # hack around config being frozen
+        # TODO(future PR): we should make this nicer at the config level
+        object.__setattr__(config, "emulate", True)
+    else:
+        config = Float8LinearConfig(emulate=True)

     toy_model = ToyModel().to(device)
     toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
@@ -196,14 +213,28 @@ def _test_fp8_mlp_tensor_parallelism_base(
     sp_model = copy.deepcopy(toy_model)
     sp_model = convert_to_float8_training(sp_model, config=config)

+    # For tensorwise scaling, enable float8 all_gather.
+    # For rowwise scaling, keep high precision all_gather. Motivation for
+    # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
+    # so for float8 all-gather we'd need to send two float8 copies per tensor,
+    # which is similar # bytes over the wire than just doing bfloat16 all-gather.
+    if rowwise:
+        colwise_parallel_cls = ColwiseParallel
+        rowwise_parallel_cls = RowwiseParallel
+        prepare_input_cls = PrepareModuleInput
+    else:
+        colwise_parallel_cls = Float8ColwiseParallel
+        rowwise_parallel_cls = Float8RowwiseParallel
+        prepare_input_cls = PrepareFloat8ModuleInput
+
     # vanilla TP
     tp_model = parallelize_module(
         tp_model,
         mesh,
         {
-            "ffn.w1": Float8ColwiseParallel(),
-            "ffn.w2": Float8ColwiseParallel(),
-            "ffn.out_proj": Float8RowwiseParallel(),
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(),
         },
     )

@@ -212,33 +243,41 @@ def _test_fp8_mlp_tensor_parallelism_base(
         sp_model,
         mesh,
         {
-            "ffn": PrepareFloat8ModuleInput(
+            "ffn": prepare_input_cls(
                 input_layouts=Shard(1), desired_input_layouts=Replicate()
             ),
-            "ffn.w1": Float8ColwiseParallel(),
-            "ffn.w2": Float8ColwiseParallel(),
-            "ffn.out_proj": Float8RowwiseParallel(
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
                 output_layouts=Shard(1), use_local_output=False
             ),
         },
     )

-    # PrepareFloat8ModuleInput with specific submodule fqn
+    # prepare_input_cls with specific submodule fqn
     sp_model2 = copy.deepcopy(toy_model)
     sp_model2 = convert_to_float8_training(sp_model2, config=config)

+    if rowwise:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+        )
+    else:
+        prepare_input = prepare_input_cls(
+            input_layouts=Shard(1),
+            desired_input_layouts=Replicate(),
+            fwd_config_submodule_fqn="w2",
+        )
+
     sp_model2 = parallelize_module(
         sp_model2,
         mesh,
         {
-            "ffn": PrepareFloat8ModuleInput(
-                input_layouts=Shard(1),
-                desired_input_layouts=Replicate(),
-                fwd_config_submodule_fqn="w2",
-            ),
-            "ffn.w1": Float8ColwiseParallel(),
-            "ffn.w2": Float8ColwiseParallel(),
-            "ffn.out_proj": Float8RowwiseParallel(
+            "ffn": prepare_input,
+            "ffn.w1": colwise_parallel_cls(),
+            "ffn.w2": colwise_parallel_cls(),
+            "ffn.out_proj": rowwise_parallel_cls(
                 output_layouts=Shard(1), use_local_output=False
             ),
         },
@@ -278,11 +317,13 @@ def _test_fp8_mlp_tensor_parallelism_base(


 def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)


 def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)


 def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
```
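One detail in the rowwise branch above: `Float8LinearConfig` is a frozen dataclass, which is why the test reaches for `object.__setattr__` to flip `emulate`. Below is a standalone illustration of that behavior; `FrozenConfig` is a hypothetical stand-in, not torchao code.

```python
# Why the test uses object.__setattr__: frozen dataclasses reject normal
# attribute assignment after construction.
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class FrozenConfig:  # hypothetical stand-in for Float8LinearConfig
    emulate: bool = False


cfg = FrozenConfig()
try:
    cfg.emulate = True  # raises FrozenInstanceError
except FrozenInstanceError:
    pass

# the workaround used in the test: bypass the frozen check
object.__setattr__(cfg, "emulate", True)
print(cfg.emulate)  # True
```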

torchao/float8/float8_linear.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -168,8 +168,10 @@ def backward(ctx, grad_output):
         ):
             # workaround from https://github.com/pytorch/pytorch/issues/141881
             # to avoid saving float8 weight from forward to backward when
-            # FSDP is on
-            weight_hp_t = weight_hp_t + (grad_output_reshaped[0, 0] * 0)
+            # FSDP is on: add a fake dependency on `grad_output`.
+            g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0
+            zero = g_reshaped[:1] * 0
+            weight_hp_t = weight_hp_t + zero

             # Note: we need https://github.com/pytorch/pytorch/issues/136267
             # to be solved to have a chance to reuse max(abs(weight, dim=...))
```
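Per the commit summary, the workaround was reworked so it still functions when `grad_output` is a DTensor (the old `grad_output_reshaped[0, 0]` scalar indexing is what got replaced). The trick leaves the weight numerically unchanged while making it data-dependent on the gradient. A plain-tensor sketch with stand-in shapes, not torchao code:

```python
# Minimal illustration of the "fake dependency" pattern from the diff above:
# adding a zeroed slice of grad_output to the (transposed) weight keeps the
# values identical but wires grad_output into the computation graph.
import torch

g = torch.randn(8, 4)     # stands in for grad_output
w_t = torch.randn(16, 4)  # stands in for the high-precision transposed weight

g_reshaped = g.reshape(-1, g.shape[-1]) * 0  # shape [8, 4], all zeros
zero = g_reshaped[:1] * 0                    # shape [1, 4], broadcastable
w_t_with_dep = w_t + zero                    # numerically identical to w_t

assert torch.equal(w_t, w_t_with_dep)
```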

torchao/float8/float8_ops.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -113,11 +113,25 @@ def float8_transpose(aten_op, args, kwargs=None):

 @implements([aten.view.default])
 def float8_view(aten_op, args, kwargs=None):
+    t, new_shape = args[0], args[1]
+
+    # if the new shape is the same as old, return an equivalent tensor
+    # note that we have to create a new wrapper to make PyTorch internals happy
+    if new_shape == list(t._data.shape):
+        new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+        return Float8Tensor(
+            new_data,
+            args[0]._scale,
+            args[0]._orig_dtype,
+            args[0]._linear_mm_config,
+            args[0]._gemm_input_role,
+            args[0]._axiswise_dim,
+        )
+
     if len(args[0]._scale.shape) < 2:
         # tensorwise scaling
         return float8_desugar_op(aten_op, args, kwargs)

-    t, new_shape = args[0], args[1]
     # for now, only support reshaping to [-1, dim] or [dim, -1]
     axiswise_dim = t._axiswise_dim
     if len(new_shape) == 2:
@@ -146,6 +160,7 @@ def float8_view(aten_op, args, kwargs=None):
                 t._gemm_input_role,
                 new_axiswise_dim,
             )
+
     raise AssertionError(
         f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} t._axiswise_dim {t._axiswise_dim} new_shape {new_shape} is not supported yet."
     )
```
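A hand-rolled sketch of why the same-shape fast path has to re-wrap rather than return the input as-is: under `__torch_dispatch__`, each op call is expected to produce a fresh wrapper around the (possibly viewed) inner data. `WrappedTensor` below is a toy stand-in for illustration, not torchao's `Float8Tensor`.

```python
# Toy wrapper subclass demonstrating the "same-shape view" fast path idea.
import torch


class WrappedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.view.default:
            t, new_shape = args[0], args[1]
            if list(new_shape) == list(t._data.shape):
                # fast path: shape unchanged, view the inner data and re-wrap
                return WrappedTensor(func(t._data, new_shape), t._scale)
            raise AssertionError("only same-shape views supported in this sketch")
        raise NotImplementedError(f"{func} not supported in this sketch")


x = WrappedTensor(torch.randn(4, 8), torch.ones(4, 1))
y = x.view(4, 8)  # hits the fast path
print(type(y).__name__, tuple(y.shape))  # WrappedTensor (4, 8)
```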

torchao/float8/float8_tensor_parallel.py

Lines changed: 27 additions & 12 deletions
```diff
@@ -36,6 +36,11 @@ def _float8_linear_supports_float8_allgather(m):


 class Float8ColwiseParallel(ColwiseParallel):
+    """
+    Like `ColwiseParallel`, but with all-gather in float8. This
+    currently assumes tensorwise scaling.
+    """
+
     @staticmethod
     def _prepare_input_fn(
         input_layouts, desired_input_layouts, mod, inputs, device_mesh
@@ -96,6 +101,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:


 class Float8RowwiseParallel(RowwiseParallel):
+    """
+    Like `RowwiseParallel`, but with all-gather in float8. This
+    currently assumes tensorwise scaling.
+    """
+
     @staticmethod
     def _prepare_input_fn(
         input_layouts, desired_input_layouts, mod, inputs, device_mesh
@@ -154,18 +164,23 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:


 class PrepareFloat8ModuleInput(PrepareModuleInput):
-    # subclass the PrepareModuleInput classes to implement fp8 specific logic, the only difference is that
-    # after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor)
-    # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate)
-    # so that if there are multiple float8 users of the input activation, we perform fp8 allgather
-    # only once.
-    # FP8 Args:
-    # float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input,
-    # we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn
-    # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used
-    # for the float8 cast. If not specified, we will search for the Float8Linear in the submodules
-    # and use the forward config from that module, in this case all module's forward config must be
-    # the same.
+    """
+    Like `PrepareModuleInput`, but with all-gather in float8. This
+    currently assumes tensorwise scaling.
+
+    The only difference from `PrepareModuleInput` is that
+    after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor)
+    This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate)
+    so that if there are multiple float8 users of the input activation, we perform fp8 allgather
+    only once.
+    FP8 Args:
+    float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input,
+    we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn
+    fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used
+    for the float8 cast. If not specified, we will search for the Float8Linear in the submodules
+    and use the forward config from that module, in this case all module's forward config must be
+    the same.
+    """

     def __init__(
         self,
```