diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 37cf0d5..64ed87f 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -357,12 +357,9 @@ def cast_w_to_float8(
             )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
-            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
-            else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(
-                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
-                )
+            w_fp8 = cast_to_float8_e4m3_dynamic(
+                self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+            )
         return w_fp8
 
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -407,8 +404,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_pre_forward(input)
 
-        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        with torch.profiler.record_function("cast_x_to_float8"):
+            x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
+        with torch.profiler.record_function("cast_w_to_float8"):
+            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
         y = torch.matmul(x_fp8, w_fp8.t())
 
diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py
index 4c5297c..899d93b 100644
--- a/float8_experimental/float8_tensor_parallel.py
+++ b/float8_experimental/float8_tensor_parallel.py
@@ -44,12 +44,13 @@ def _prepare_input_fn(
             input_tensor = DTensor.from_local(
                 input_tensor, device_mesh, input_layouts, run_check=False
             )
-
-        input_tensor = cast_to_float8_e4m3_dynamic(
-            input_tensor,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.X,
-        )  # DTensor(Float8Tensor)
+
+        with torch.profiler.record_function("colwise_cast_to_float8_e4m3_dynamic"):
+            input_tensor = cast_to_float8_e4m3_dynamic(
+                input_tensor,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
+            )  # DTensor(Float8Tensor)
 
         # transform the input layouts to the desired layouts of ColwiseParallel
         if input_layouts != desired_input_layouts:
@@ -67,7 +68,8 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
             )  # DTensor(torch.Tensor)
 
         # fwd noop bwd cast to DTensor(Float8Tensor)
-        outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)
+        with torch.profiler.record_function("colwise_cast_to_float8_e5m2_dynamic_bw"):
+            outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)
 
         # back to local tensor
         return outputs.to_local() if use_local_output else outputs
@@ -98,11 +100,12 @@ def _prepare_input_fn(
                 input_tensor, device_mesh, input_layouts, run_check=False
             )
 
-        input_tensor = cast_to_float8_e4m3_dynamic(
-            input_tensor,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.X,
-        )  # DTensor(Float8Tensor)
+        with torch.profiler.record_function("rowwise_cast_to_float8_e4m3_dynamic"):
+            input_tensor = cast_to_float8_e4m3_dynamic(
+                input_tensor,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
+            )  # DTensor(Float8Tensor)
 
         if input_layouts != desired_input_layouts:
             input_tensor = input_tensor.redistribute(
@@ -119,7 +122,8 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
             outputs = outputs.redistribute(placements=output_layouts, async_op=True)
 
         # fwd noop bwd cast to DTensor(Float8Tensor)
-        outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)
+        with torch.profiler.record_function("rowwise_cast_to_float8_e5m2_dynamic_bw"):
+            outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)
 
         # back to local tensor if use_local_output is True
         return outputs.to_local() if use_local_output else outputs
@@ -196,11 +200,12 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
                     input, mesh, (input_layout,), run_check=False
                 )
 
-            dt_inp = cast_to_float8_e4m3_dynamic(
-                dt_inp,
-                self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
-            )  # DTensor(Float8Tensor)
+            with torch.profiler.record_function("prepareinput_cast_to_float8_e4m3_dynamic"):
+                dt_inp = cast_to_float8_e4m3_dynamic(
+                    dt_inp,
+                    self.linear_mm_config,
+                    gemm_input_role=GemmInputRole.X,
+                )  # DTensor(Float8Tensor)
 
             if desired_layout is not None and input_layout != desired_layout:
                 dt_inp = dt_inp.redistribute(placements=(desired_layout,))
 
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 04cd797..d228189 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     torch.ops.aten.as_strided.default,
     torch.ops.aten._to_copy.default,
     torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
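+    # aten.split / aten.clone are presumably hit when DTensor's distribute_tensor
+    # chunks and clones the weight for TP sharding; preserving the subclass here
+    # keeps each local shard a WeightWithDynamicFloat8CastTensor (checked by
+    # Test2DFloat8MultiProcess.test_fsdp_tp in test_fsdp2.py).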
torch.profiler.record_function("rowwise_cast_to_float8_e5m2_dynamic_bw"): + outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs @@ -196,11 +200,12 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): input, mesh, (input_layout,), run_check=False ) - dt_inp = cast_to_float8_e4m3_dynamic( - dt_inp, - self.linear_mm_config, - gemm_input_role=GemmInputRole.X, - ) # DTensor(Float8Tensor) + with torch.profiler.record_function("prepareinput_cast_to_float8_e4m3_dynamic"): + dt_inp = cast_to_float8_e4m3_dynamic( + dt_inp, + self.linear_mm_config, + gemm_input_role=GemmInputRole.X, + ) # DTensor(Float8Tensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 04cd797..d228189 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: torch.ops.aten.as_strided.default, torch.ops.aten._to_copy.default, torch.ops.aten._pin_memory.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.clone.default, } @@ -138,6 +140,10 @@ def unwrap(t): WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {}) ) out = func(*args, **kwargs) + if func is torch.ops.aten.split.Tensor: + # if func is torch.ops.aten.clone.default: + if torch.distributed.get_rank() == 0: + print(f"dispatched {func=}", flush=True) if func not in _ops_to_preserve_subclass: return out return pytree.tree_map_only( @@ -188,12 +194,22 @@ def fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ): + from torch.distributed._tensor import DTensor + (data,) = all_gather_outputs (scale,) = metadata if out is not None: - assert isinstance(out, Float8Tensor), f"{type(out)}" - out._scale = scale - return + if isinstance(out, Float8Tensor): + out._scale = scale + elif isinstance(out, DTensor) and isinstance( + out._local_tensor, Float8Tensor + ): + out._local_tensor._scale = scale + else: + raise RuntimeError( + f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}" + ) + return out return Float8Tensor( data, scale, diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index 1cbec77..6005834 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -16,8 +16,22 @@ check_parity_no_mp, set_enable_fsdp_fp8_all_gather, ) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) +from torch.distributed._tensor import Replicate, Shard from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed._tensor import DTensor +from torch.distributed._tensor import ( + distribute_tensor, + DTensor, + init_device_mesh, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -516,5 +530,135 @@ def test_delayed_scaling_inplace_update(self): self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item()) +class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common): + @property + def world_size(self) -> int: + return 
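+            # 2D (FSDP2 + TP) case: `out` is a DTensor whose local tensor is the
+            # Float8Tensor, so write the all-gathered scale onto the local tensor.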
+            elif isinstance(out, DTensor) and isinstance(
+                out._local_tensor, Float8Tensor
+            ):
+                out._local_tensor._scale = scale
+            else:
+                raise RuntimeError(
+                    f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}"
+                )
+            return out
         return Float8Tensor(
             data,
             scale,
diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py
index 1cbec77..6005834 100644
--- a/test/test_fsdp2/test_fsdp2.py
+++ b/test/test_fsdp2/test_fsdp2.py
@@ -16,8 +16,22 @@
     check_parity_no_mp,
     set_enable_fsdp_fp8_all_gather,
 )
+from float8_experimental.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
+)
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+    SequenceParallel,
+)
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import (
+    distribute_tensor,
+    DTensor,
+    init_device_mesh,
+    Replicate,
+    Shard,
+)
+from torch.distributed.device_mesh import DeviceMesh
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -516,5 +530,135 @@ def test_delayed_scaling_inplace_update(self):
         self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item())
 
+
+class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return min(torch.cuda.device_count(), 4)
+
+    def init_global_mesh(self) -> DeviceMesh:
+        dp_size = 2 if self.world_size > 2 else 1
+        return init_device_mesh(
+            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+        )
+
+    def parallelize(
+        self, module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool
+    ) -> nn.Module:
+        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
+        module_tp = parallelize_module(
+            module,
+            device_mesh,
+            {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(1)
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(0)
+                ),
+                "norm": SequenceParallel(),
+            },
+        )
+
+        # Parallelize the attention and feed forward submodules of each
+        # transformer block with the float8-aware TP plans.
+        for transformer_block in module_tp.layers:
+            layer_parallelize_plan = {
+                "attention": PrepareFloat8ModuleInput(
+                    input_layouts=(Shard(1), None),
+                    desired_input_layouts=(Replicate(), None),
+                ),
+                # shard the RMSNorms
+                "attention_norm": SequenceParallel(),
+                "ffn_norm": SequenceParallel(),
+                "attention.wq": Float8ColwiseParallel(),
+                "attention.wk": Float8ColwiseParallel(),
+                "attention.wv": Float8ColwiseParallel(),
+                "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
+                "feed_forward": PrepareFloat8ModuleInput(
+                    input_layouts=(Shard(1),),
+                    desired_input_layouts=(Replicate(),),
+                ),
+                "feed_forward.w1": Float8ColwiseParallel(),
+                "feed_forward.w2": Float8RowwiseParallel(output_layouts=Shard(1)),
+                "feed_forward.w3": Float8ColwiseParallel(),
+            }
+
+            # Adjust the attention module to use the local number of heads.
+            attn_layer = transformer_block.attention
+            attn_layer.n_heads = attn_layer.n_heads // device_mesh.size()
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // device_mesh.size()
+
+            parallelize_module(
+                module=transformer_block,
+                device_mesh=device_mesh,
+                parallelize_plan=layer_parallelize_plan,
+            )
+
+        # Parallelize the output submodule. If weight tying is enabled, we need to
+        # make sure output.weight is sharded consistently as tok_embeddings.weight,
+        # at the cost of the all_reduce operation using RowwiseParallel.
+        output_parallelize_plan = (
+            ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+            )
+            if use_seq_parallel
+            else ColwiseParallel(output_layouts=Replicate())
+        )
+        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
+
+        # Manually set output.weight so that parameters and gradients are shared.
+        if module_tp.model_args.weight_tying:
+            module_tp.output.weight = module_tp.tok_embeddings.weight
+
+        return module_tp
+
+    @skip_if_lt_x_gpu(4)
+    def test_fsdp_tp(self):
+        enable_fsdp_fp8_all_gather = True
+        scaling_type_w = TensorScalingType.DYNAMIC
+        global_mesh = self.init_global_mesh()
+        _, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        model = self.init_transformer(weight_tying=False).cuda()
+        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+            swap_linear_with_float8_linear(model, scaling_type_w=scaling_type_w)
+
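+        # distribute_tensor() shards via aten.split/aten.clone, which are now in
+        # _ops_to_preserve_subclass, so each local shard is expected to remain a
+        # WeightWithDynamicFloat8CastTensor.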
+        # "attention.wq": Float8ColwiseParallel
+        colwise_param = distribute_tensor(
+            model.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
+        )
+        self.assertTrue(
+            isinstance(colwise_param, DTensor)
+            and isinstance(
+                colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+        # "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1))
+        rowwise_param = distribute_tensor(
+            model.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
+        )
+        self.assertTrue(
+            isinstance(rowwise_param, DTensor)
+            and isinstance(
+                rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+
 
 if __name__ == "__main__":
     run_tests()