|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# mypy: allow-untyped-defs |
| 8 | +# Copyright (c) Meta Platforms, Inc. and affiliates |
| 9 | +from functools import partial |
| 10 | +from typing import Any, Dict, Optional, Tuple, Union |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.nn as nn |
| 14 | +from torch.distributed.tensor import ( |
| 15 | + DeviceMesh, |
| 16 | + distribute_module, |
| 17 | + distribute_tensor, |
| 18 | + DTensor, |
| 19 | + Partial, |
| 20 | + Replicate, |
| 21 | + Shard, |
| 22 | +) |
| 23 | +from torch.distributed.tensor.parallel import ParallelStyle |
| 24 | +from torch.distributed.tensor.placement_types import Placement |
| 25 | + |
| 26 | + |
| 27 | +# This is similar to PrepareModuleInput and PrepareModuleOutput, |
| 28 | +# but applies them simultaneously. |
| 29 | +class PrepareModuleInputOutput(ParallelStyle): |
| 30 | + def __init__( |
| 31 | + self, |
| 32 | + *, |
| 33 | + input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, |
| 34 | + desired_input_layouts: Optional[ |
| 35 | + Union[Placement, Tuple[Optional[Placement]]] |
| 36 | + ] = None, |
| 37 | + input_kwarg_layouts: Optional[Dict[str, Placement]] = None, |
| 38 | + desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None, |
| 39 | + use_local_input: bool = True, |
| 40 | + output_layouts: Union[Placement, Tuple[Placement]], |
| 41 | + desired_output_layouts: Union[Placement, Tuple[Placement]], |
| 42 | + use_local_output: bool = True, |
| 43 | + ): |
| 44 | + # for input |
| 45 | + self.input_layouts = ( |
| 46 | + (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts |
| 47 | + ) |
| 48 | + self.desired_input_layouts = ( |
| 49 | + (desired_input_layouts,) |
| 50 | + if isinstance(desired_input_layouts, Placement) |
| 51 | + else desired_input_layouts |
| 52 | + ) |
| 53 | + self.use_local_input = use_local_input |
| 54 | + if self.input_layouts is not None: |
| 55 | + assert ( |
| 56 | + self.desired_input_layouts is not None |
| 57 | + ), "desired module inputs should not be None!" |
| 58 | + assert len(self.input_layouts) == len( |
| 59 | + self.desired_input_layouts |
| 60 | + ), "input_layouts and desired_input_layouts should have same length!" |
| 61 | + self.with_kwargs = input_kwarg_layouts is not None |
| 62 | + self.input_kwarg_layouts = input_kwarg_layouts or {} |
| 63 | + self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} |
| 64 | + if self.with_kwargs: |
| 65 | + assert len(self.input_kwarg_layouts) == len( |
| 66 | + self.desired_input_kwarg_layouts |
| 67 | + ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" |
| 68 | + |
| 69 | + # for output |
| 70 | + self.output_layouts = ( |
| 71 | + (output_layouts,) |
| 72 | + if isinstance(output_layouts, Placement) |
| 73 | + else output_layouts |
| 74 | + ) |
| 75 | + self.desired_output_layouts = ( |
| 76 | + (desired_output_layouts,) |
| 77 | + if isinstance(desired_output_layouts, Placement) |
| 78 | + else desired_output_layouts |
| 79 | + ) |
| 80 | + self.use_local_output = use_local_output |
| 81 | + assert len(self.output_layouts) == len( |
| 82 | + self.desired_output_layouts |
| 83 | + ), "output_layouts and desired_output_layouts should have same length!" |
| 84 | + |
| 85 | + def _prepare_input_arg( |
| 86 | + self, |
| 87 | + input: Any, |
| 88 | + mesh: DeviceMesh, |
| 89 | + input_layout: Optional[Placement], |
| 90 | + desired_layout: Optional[Placement], |
| 91 | + ): |
| 92 | + if input_layout is not None: |
| 93 | + if isinstance(input, DTensor): |
| 94 | + # TODO: re-enable the check once we fix the compile path |
| 95 | + # assert inp.placements[0] == input_layout |
| 96 | + dt_inp = input |
| 97 | + else: |
| 98 | + assert isinstance( |
| 99 | + input, torch.Tensor |
| 100 | + ), "expecting input to be a torch.Tensor!" |
| 101 | + dt_inp = DTensor.from_local( |
| 102 | + input, mesh, (input_layout,), run_check=False |
| 103 | + ) |
| 104 | + |
| 105 | + if desired_layout is not None and input_layout != desired_layout: |
| 106 | + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) |
| 107 | + |
| 108 | + return dt_inp.to_local() if self.use_local_input else dt_inp |
| 109 | + else: |
| 110 | + return input |
| 111 | + |
| 112 | + def _prepare_input_fn(self, inputs, device_mesh): |
| 113 | + if self.input_layouts is None: |
| 114 | + return inputs |
| 115 | + prepared_inputs = [] |
| 116 | + if not isinstance(inputs, tuple): |
| 117 | + inputs = (inputs,) |
| 118 | + if len(inputs) != len(self.input_layouts): |
| 119 | + raise ValueError("module inputs and input_layouts should have same length!") |
| 120 | + |
| 121 | + assert ( |
| 122 | + self.desired_input_layouts is not None |
| 123 | + ), "desired module inputs should not be None!" |
| 124 | + for inp, input_layout, desired_layout in zip( |
| 125 | + inputs, self.input_layouts, self.desired_input_layouts |
| 126 | + ): |
| 127 | + prepared_inputs.append( |
| 128 | + self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout) |
| 129 | + ) |
| 130 | + return tuple(prepared_inputs) |
| 131 | + |
| 132 | + def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): |
| 133 | + prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) |
| 134 | + prepared_kwarg_inputs = {} |
| 135 | + for kwarg_key in kwarg_inputs.keys(): |
| 136 | + kwarg_val = kwarg_inputs[kwarg_key] |
| 137 | + input_layout = self.input_kwarg_layouts.get(kwarg_key) |
| 138 | + desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) |
| 139 | + |
| 140 | + prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg( |
| 141 | + kwarg_val, device_mesh, input_layout, desired_input_layout |
| 142 | + ) |
| 143 | + |
| 144 | + return (prepared_arg_inputs, prepared_kwarg_inputs) |
| 145 | + |
| 146 | + def _prepare_out_fn(self, outputs, device_mesh): |
| 147 | + prepared_outputs = [] |
| 148 | + if not isinstance(outputs, tuple): |
| 149 | + outputs = (outputs,) |
| 150 | + if len(outputs) != len(self.output_layouts): |
| 151 | + raise ValueError( |
| 152 | + "module outputs and output_layouts should have same length!" |
| 153 | + ) |
| 154 | + for out, out_layout, desired_out_layout in zip( |
| 155 | + outputs, self.output_layouts, self.desired_output_layouts |
| 156 | + ): |
| 157 | + if out_layout is not None: |
| 158 | + if isinstance(out, DTensor): |
| 159 | + # TODO: re-enable the check once we fix the compile path |
| 160 | + # assert out.placements[0] == out_layout |
| 161 | + dt_out = out |
| 162 | + else: |
| 163 | + dt_out = DTensor.from_local( |
| 164 | + out, device_mesh, (out_layout,), run_check=False |
| 165 | + ) |
| 166 | + |
| 167 | + if out_layout != desired_out_layout: |
| 168 | + dt_out = dt_out.redistribute(placements=(desired_out_layout,)) |
| 169 | + prepared_outputs.append( |
| 170 | + dt_out.to_local() if self.use_local_output else dt_out |
| 171 | + ) |
| 172 | + else: |
| 173 | + prepared_outputs.append(out) |
| 174 | + if len(prepared_outputs) == 1: |
| 175 | + return prepared_outputs[0] |
| 176 | + else: |
| 177 | + return tuple(prepared_outputs) |
| 178 | + |
| 179 | + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 180 | + # for input |
| 181 | + if self.with_kwargs: |
| 182 | + module.register_forward_pre_hook( |
| 183 | + lambda _, inputs, kwargs: self._prepare_input_kwarg_fn( |
| 184 | + inputs, kwargs, device_mesh |
| 185 | + ), |
| 186 | + with_kwargs=True, |
| 187 | + ) # type: ignore[misc] |
| 188 | + else: |
| 189 | + module.register_forward_pre_hook( |
| 190 | + lambda _, inputs: self._prepare_input_fn(inputs, device_mesh) |
| 191 | + ) # type: ignore[misc, call-arg] |
| 192 | + |
| 193 | + # for output |
| 194 | + module.register_forward_hook( |
| 195 | + lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh) |
| 196 | + ) # type: ignore[misc, call-arg] |
| 197 | + |
| 198 | + return module |
| 199 | + |
| 200 | + |
| 201 | +class TensorParallel(ParallelStyle): |
| 202 | + def __init__( |
| 203 | + self, |
| 204 | + *, |
| 205 | + input_layouts: Optional[Placement] = None, |
| 206 | + output_layouts: Optional[Placement] = None, |
| 207 | + use_local_output: bool = True, |
| 208 | + ): |
| 209 | + super().__init__() |
| 210 | + self.input_layouts = (input_layouts or Replicate(),) |
| 211 | + self.output_layouts = (output_layouts or Partial(),) |
| 212 | + self.desired_input_layouts = (Replicate(),) |
| 213 | + self.use_local_output = use_local_output |
| 214 | + |
| 215 | + @staticmethod |
| 216 | + def _prepare_input_fn( |
| 217 | + input_layouts, desired_input_layouts, mod, inputs, device_mesh |
| 218 | + ): |
| 219 | + # TODO: figure out dynamo support for instance method and switch this to instance method |
| 220 | + |
| 221 | + # annotate module input placements/sharding with input_layouts |
| 222 | + input_tensor = inputs[0] |
| 223 | + if not isinstance(input_tensor, DTensor): |
| 224 | + input_tensor = DTensor.from_local( |
| 225 | + input_tensor, device_mesh, input_layouts, run_check=False |
| 226 | + ) |
| 227 | + |
| 228 | + if input_layouts != desired_input_layouts: |
| 229 | + input_tensor = input_tensor.redistribute( |
| 230 | + placements=desired_input_layouts, async_op=True |
| 231 | + ) |
| 232 | + return input_tensor |
| 233 | + |
| 234 | + def _partition_fn(self, name, module, device_mesh): |
| 235 | + module.register_parameter( |
| 236 | + "gate_proj", |
| 237 | + nn.Parameter(distribute_tensor(module.gate_proj, device_mesh, [Shard(2)])), |
| 238 | + ) # Column-wise sharding |
| 239 | + module.register_parameter( |
| 240 | + "down_proj", |
| 241 | + nn.Parameter(distribute_tensor(module.down_proj, device_mesh, [Shard(1)])), |
| 242 | + ) # Row-wise sharding |
| 243 | + module.register_parameter( |
| 244 | + "up_proj", |
| 245 | + nn.Parameter(distribute_tensor(module.up_proj, device_mesh, [Shard(2)])), |
| 246 | + ) # Column-wise sharding |
| 247 | + |
| 248 | + @staticmethod |
| 249 | + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): |
| 250 | + if outputs.placements != output_layouts: |
| 251 | + outputs = outputs.redistribute(placements=output_layouts, async_op=True) |
| 252 | + # back to local tensor |
| 253 | + return outputs.to_local() if use_local_output else outputs |
| 254 | + |
| 255 | + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 256 | + return distribute_module( |
| 257 | + module, |
| 258 | + device_mesh, |
| 259 | + self._partition_fn, |
| 260 | + partial( |
| 261 | + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts |
| 262 | + ), |
| 263 | + partial( |
| 264 | + self._prepare_output_fn, self.output_layouts, self.use_local_output |
| 265 | + ), |
| 266 | + ) |
| 267 | + |
| 268 | + |
| 269 | +class ExpertParallel(ParallelStyle): |
| 270 | + def __init__( |
| 271 | + self, |
| 272 | + *, |
| 273 | + input_layouts: Optional[Placement] = None, |
| 274 | + output_layouts: Optional[Placement] = None, |
| 275 | + use_local_output: bool = True, |
| 276 | + ): |
| 277 | + super().__init__() |
| 278 | + self.input_layouts = (input_layouts or Shard(0),) |
| 279 | + self.output_layouts = (output_layouts or Shard(0),) |
| 280 | + self.desired_input_layouts = (Shard(0),) |
| 281 | + self.use_local_output = use_local_output |
| 282 | + |
| 283 | + @staticmethod |
| 284 | + def _prepare_input_fn( |
| 285 | + input_layouts, desired_input_layouts, mod, inputs, device_mesh |
| 286 | + ): |
| 287 | + # TODO: figure out dynamo support for instance method and switch this to instance method |
| 288 | + |
| 289 | + # annotate module input placements/sharding with input_layouts |
| 290 | + input_tensor = inputs[0] |
| 291 | + if not isinstance(input_tensor, DTensor): |
| 292 | + input_tensor = DTensor.from_local( |
| 293 | + input_tensor, device_mesh, input_layouts, run_check=False |
| 294 | + ) |
| 295 | + |
| 296 | + if input_layouts != desired_input_layouts: |
| 297 | + input_tensor = input_tensor.redistribute( |
| 298 | + placements=desired_input_layouts, async_op=True |
| 299 | + ) |
| 300 | + return input_tensor |
| 301 | + |
| 302 | + def _partition_fn(self, name, module, device_mesh): |
| 303 | + # shard on the expert dimension |
| 304 | + for name, param in module.named_parameters(recurse=False): |
| 305 | + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) |
| 306 | + module.register_parameter(name, dist_param) |
| 307 | + |
| 308 | + @staticmethod |
| 309 | + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): |
| 310 | + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) |
| 311 | + if outputs.placements != output_layouts: |
| 312 | + outputs = outputs.redistribute(placements=output_layouts, async_op=True) |
| 313 | + # back to local tensor |
| 314 | + return outputs.to_local() if use_local_output else outputs |
| 315 | + |
| 316 | + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 317 | + return distribute_module( |
| 318 | + module, |
| 319 | + device_mesh, |
| 320 | + self._partition_fn, |
| 321 | + partial( |
| 322 | + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts |
| 323 | + ), |
| 324 | + partial( |
| 325 | + self._prepare_output_fn, self.output_layouts, self.use_local_output |
| 326 | + ), |
| 327 | + ) |
0 commit comments