from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
- from torch.distributed.tensor.parallel import parallelize_module
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     PrepareModuleInput,
+     RowwiseParallel,
+     parallelize_module,
+ )
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
- from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype
+ from torchao.float8.config import (
+     CastConfig,
+     Float8LinearRecipeName,
+     ScalingType,
+     e4m3_dtype,
+     recipe_name_to_linear_config,
+ )
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import (
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.dtensor_utils import ToyModel

+ torch.set_float32_matmul_precision("high")
+

def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", -1))
@@ -180,13 +193,17 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):


def _test_fp8_mlp_tensor_parallelism_base(
-     mesh: DeviceMesh, size=16, compile: bool = False
+     mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False
):
    device = mesh.device_type
-     # For now, only supports dynamic scaling of `x` and `dL_dY`.
-     # TODO(future): add support for float8 all-gather with delayed scaling
-     # for activations and gradients.
-     config = Float8LinearConfig(emulate=True)
+
+     if rowwise:
+         config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+         # hack around config being frozen
+         # TODO(future PR): we should make this nicer at the config level
+         object.__setattr__(config, "emulate", True)
+     else:
+         config = Float8LinearConfig(emulate=True)

    toy_model = ToyModel().to(device)
    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
@@ -196,14 +213,28 @@ def _test_fp8_mlp_tensor_parallelism_base(
    sp_model = copy.deepcopy(toy_model)
    sp_model = convert_to_float8_training(sp_model, config=config)

+     # For tensorwise scaling, enable float8 all_gather.
+     # For rowwise scaling, keep high precision all_gather. Motivation for
+     # not doing float8 all-gather for rowwise: tensors need to be scaled both ways,
+     # so for float8 all-gather we'd need to send two float8 copies per tensor,
+     # which is a similar number of bytes over the wire as just doing bfloat16 all-gather.
+     if rowwise:
+         colwise_parallel_cls = ColwiseParallel
+         rowwise_parallel_cls = RowwiseParallel
+         prepare_input_cls = PrepareModuleInput
+     else:
+         colwise_parallel_cls = Float8ColwiseParallel
+         rowwise_parallel_cls = Float8RowwiseParallel
+         prepare_input_cls = PrepareFloat8ModuleInput
+
    # vanilla TP
    tp_model = parallelize_module(
        tp_model,
        mesh,
        {
-             "ffn.w1": Float8ColwiseParallel(),
-             "ffn.w2": Float8ColwiseParallel(),
-             "ffn.out_proj": Float8RowwiseParallel(),
+             "ffn.w1": colwise_parallel_cls(),
+             "ffn.w2": colwise_parallel_cls(),
+             "ffn.out_proj": rowwise_parallel_cls(),
        },
    )

@@ -212,33 +243,41 @@ def _test_fp8_mlp_tensor_parallelism_base(
        sp_model,
        mesh,
        {
-             "ffn": PrepareFloat8ModuleInput(
+             "ffn": prepare_input_cls(
                input_layouts=Shard(1), desired_input_layouts=Replicate()
            ),
-             "ffn.w1": Float8ColwiseParallel(),
-             "ffn.w2": Float8ColwiseParallel(),
-             "ffn.out_proj": Float8RowwiseParallel(
+             "ffn.w1": colwise_parallel_cls(),
+             "ffn.w2": colwise_parallel_cls(),
+             "ffn.out_proj": rowwise_parallel_cls(
                output_layouts=Shard(1), use_local_output=False
            ),
        },
    )

-     # PrepareFloat8ModuleInput with specific submodule fqn
+     # prepare_input_cls with specific submodule fqn
    sp_model2 = copy.deepcopy(toy_model)
    sp_model2 = convert_to_float8_training(sp_model2, config=config)

+     if rowwise:
+         prepare_input = prepare_input_cls(
+             input_layouts=Shard(1),
+             desired_input_layouts=Replicate(),
+         )
+     else:
+         prepare_input = prepare_input_cls(
+             input_layouts=Shard(1),
+             desired_input_layouts=Replicate(),
+             fwd_config_submodule_fqn="w2",
+         )
+
    sp_model2 = parallelize_module(
        sp_model2,
        mesh,
        {
-             "ffn": PrepareFloat8ModuleInput(
-                 input_layouts=Shard(1),
-                 desired_input_layouts=Replicate(),
-                 fwd_config_submodule_fqn="w2",
-             ),
-             "ffn.w1": Float8ColwiseParallel(),
-             "ffn.w2": Float8ColwiseParallel(),
-             "ffn.out_proj": Float8RowwiseParallel(
+             "ffn": prepare_input,
+             "ffn.w1": colwise_parallel_cls(),
+             "ffn.w2": colwise_parallel_cls(),
+             "ffn.out_proj": rowwise_parallel_cls(
                output_layouts=Shard(1), use_local_output=False
            ),
        },
@@ -278,11 +317,13 @@ def _test_fp8_mlp_tensor_parallelism_base(


def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
-     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)
+     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False)
+     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True)


def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False)
+     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True)


def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):