From 70c3a5e2ab3df9a32b1f08dfcd62020cc23275b7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 29 Jan 2025 18:19:13 -0800 Subject: [PATCH 001/230] Add `MoERouter` --- src/olmo_core/nn/moe/router.py | 135 +++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 src/olmo_core/nn/moe/router.py diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py new file mode 100644 index 000000000..e4ad5fa05 --- /dev/null +++ b/src/olmo_core/nn/moe/router.py @@ -0,0 +1,135 @@ +from abc import abstractmethod +from typing import Any, Callable, Optional, Tuple + +import torch +import torch.nn as nn + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + del ctx + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment: Callable[ + [torch.Tensor, int], torch.Tensor +] = _UniformExpertAssignment.apply # type: ignore + + +class MoERouter(nn.Module): + """ + A base class for MoE router modules. + + :param d_model: The model dimensionality (hidden size). + :param num_experts: The total number of experts. + :param top_k: The number of experts to assign to each token. + :param jitter_eps: Controls the amount of noise added to the input during training. + :param normalize_expert_weights: The type of norm (e.g. ``2.0`` for L2 norm) to use to normalize + the expert weights. + :param uniform_expert_assignment: Force uniform assignment. Useful for benchmarking. + """ + + def __init__( + self, + *, + d_model: int, + num_experts: int, + top_k: int = 1, + jitter_eps: Optional[float] = None, + normalize_expert_weights: Optional[float] = None, + uniform_expert_assignment: bool = False, + ): + super().__init__() + self.d_model = d_model + self.num_experts = num_experts + self.top_k = top_k + self.jitter_eps = jitter_eps + self.normalize_expert_weights = normalize_expert_weights + self.uniform_expert_assignment = uniform_expert_assignment + + def jitter(self, x: torch.Tensor) -> torch.Tensor: + if self.jitter_eps is None or not self.training: + return x + else: + low = 1.0 - self.jitter_eps + high = 1.0 + self.jitter_eps + noise = torch.rand_like(x) + return x * (low + noise * (high - low)) + + def get_top_k(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.top_k, dim=-1) + + @abstractmethod + def get_expert_scores(self, x: torch.Tensor) -> torch.Tensor: + """ + Given the input ``x`` of shape ``(*, d_model)``, compute the expert scores. + + :returns: The expert scores, shape ``(*, num_experts)``. + """ + raise NotImplementedError + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given the input ``x`` of shape ``(batch_size, seq_len, d_model)``, compute the + experts assignment. + + :returns: The scores of shape ``(batch_size, seq_len, num_experts)``, the expert weights + of shape ``(batch_size, seq_len, top_k)``, and the expert indices of shape + ``(batch_size, seq_len, top_k)``. 
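+
+        When ``uniform_expert_assignment`` is set, the returned expert indices are
+        replaced with a round-robin assignment, while the scores and expert weights
+        are still computed as usual.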
+ """ + # shape: (batch_size, seq_len, d_model) + x = self.jitter(x) + + # shape: (batch_size * seq_len, num_experts) + scores = self.get_expert_scores(x.view(-1, self.d_model)) + + # shape: (batch_size * seq_len, top_k) + expert_weights, expert_indices = self.get_top_k(scores) + + if self.normalize_expert_weights is not None: + expert_weights.div_( + torch.norm( + expert_weights, + p=self.normalize_expert_weights, + dim=-1, + keepdim=True, + ) + ) + + if self.uniform_expert_assignment: + expert_indices = _uniform_expert_assignment(expert_indices, self.num_experts) + + return scores, expert_weights, expert_indices + + +class MoELinearRouter(MoERouter): + """ + A simple, learned, linear router. + """ + + def __init__( + self, + *, + bias: bool = True, + dtype: torch.dtype = torch.float32, + init_device: str = "cpu", + **kwargs, + ): + super().__init__(**kwargs) + self.w_score = nn.Linear( + self.d_model, self.num_experts, bias=bias, dtype=dtype, device=init_device + ) + + def get_expert_scores(self, x: torch.Tensor) -> torch.Tensor: + logits = self.w_score(x.view(-1, self.d_model)) + # TODO: save router logits for Z-loss + return logits.softmax(dim=-1) From 5f3994e6ded112282bee6a7e610caf2f58c87502 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 14:13:43 -0800 Subject: [PATCH 002/230] Add `MoEMLP` --- src/olmo_core/nn/moe/mlp.py | 146 +++++++++++++++++++++++++++++++++ src/olmo_core/nn/moe/router.py | 77 +++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 src/olmo_core/nn/moe/mlp.py diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py new file mode 100644 index 000000000..74ccc2a65 --- /dev/null +++ b/src/olmo_core/nn/moe/mlp.py @@ -0,0 +1,146 @@ +import warnings +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import DeviceMesh +from torch.distributed.tensor import Shard, distribute_tensor + +from ...distributed.utils import get_local_tensor +from ...exceptions import OLMoConfigurationError + +__all__ = ["MoEMLP"] + + +class _ScaleGradient(torch.autograd.Function): + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type="cuda") + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type="cuda") + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None # type: ignore + + +_scale_gradient: Callable[[torch.Tensor, float], torch.Tensor] = _ScaleGradient.apply # type: ignore + + +class MoEMLP(nn.Module): + def __init__( + self, + *, + d_model: int, + hidden_size: int, + num_experts: int, + dtype: torch.dtype = torch.float32, + init_device: str = "cpu", + ): + super().__init__() + self.d_model = d_model + self.hidden_size = hidden_size + self.num_experts = num_experts + + self.gradient_scale: Optional[float] = None + self.experts_per_rank = num_experts + + self.w1 = nn.Parameter( + torch.empty( + num_experts, + hidden_size, + d_model, + device=init_device, + dtype=dtype, + ), + ) + self.w2 = nn.Parameter( + torch.empty( + num_experts, + hidden_size, + d_model, + device=init_device, + dtype=dtype, + ), + ) + self.w3 = nn.Parameter( + torch.empty( + num_experts, + hidden_size, + d_model, + device=init_device, + dtype=dtype, + ), + ) + + self._gmm = None + + try: + import grouped_gemm # type: ignore + + self._gmm = grouped_gemm.ops.gmm + except ImportError: + warnings.warn( + "Grouped GEMM not available, so the MoE will be 
substantially slower. " + "Please install the 'grouped_gemm' package if possible.\n" + "https://github.com/tgale96/grouped_gemm" + ) + + def scale_grad(self, w: torch.Tensor) -> torch.Tensor: + if self.gradient_scale is None: + return w + return _scale_gradient(w, self.gradient_scale) + + def gmm( + self, x: torch.Tensor, w: torch.Tensor, batch_sizes: torch.Tensor, trans_b: bool = False + ) -> torch.Tensor: + if self._gmm is not None: + return self._gmm(x, w, batch_sizes, trans_b=trans_b) + else: + out = [] + start = 0 + for i, size in enumerate(batch_sizes.cpu().numpy()): + rhs = w[i, :, :].t() if trans_b else w[i, :, :] + out.append(x[start : start + size, :] @ rhs) + start += size + return torch.cat(out) + + def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor: + """ + Compute the expert outputs. + + :param x: The input of shape ``(total_tokens, d_model)``. + :param tokens_per_expert: Specifies how many tokens go to each expert. Should be a + 1-D ``LongTensor``. + """ + # Scale gradients and get local tensors (in case of expert parallelism). + # shape (all): (experts_per_rank, hidden_size, d_model) + w1, w2, w3 = ( + get_local_tensor(self.scale_grad(self.w1)), + get_local_tensor(self.scale_grad(self.w2)), + get_local_tensor(self.scale_grad(self.w3)), + ) + + # Compute the MLP. + x1 = self.gmm(x, w1, tokens_per_expert, trans_b=True) + x2 = self.gmm(x, w3, tokens_per_expert, trans_b=True) + x1 = F.silu(x1) * x2 + return self.gmm(x1, w2, tokens_per_expert) + + def apply_ep(self, ep_mesh: DeviceMesh): + """ + Apply expert parallelism. + """ + if self.num_experts % ep_mesh.size() != 0: + raise OLMoConfigurationError( + f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel degree ({ep_mesh.size()})." + ) + + self.experts_per_rank = self.num_experts // ep_mesh.size() + self.gradient_scale = 1.0 / ep_mesh.size() + + self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) + self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)]))) + self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, [Shard(0)]))) diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index e4ad5fa05..a79336163 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -1,9 +1,15 @@ from abc import abstractmethod +from dataclasses import dataclass from typing import Any, Callable, Optional, Tuple import torch import torch.nn as nn +from ...config import Config, DType, StrEnum +from ...exceptions import OLMoConfigurationError + +__all__ = ["MoERouter", "MoELinearRouter", "MoERouterConfig", "MoERouterType"] + # NOTE: To enable end-to-end benchmarking without convergence we # support a flag to force the router to assign tokens uniformly @@ -23,6 +29,77 @@ def forward(ctx: Any, x: torch.Tensor, num_experts: int): ] = _UniformExpertAssignment.apply # type: ignore +class MoERouterType(StrEnum): + """ + An enumeration of the different MoE router implementations. + """ + + default = "default" + """ + ➡️ :class:`MoELinearRouter` + """ + + +@dataclass +class MoERouterConfig(Config): + """ + A configuration class for easily building any of the different MoE router modules. + """ + + name: MoERouterType = MoERouterType.default + """ + The name of the implementation. 
+ """ + num_experts: int = 1 + top_k: int = 1 + jitter_eps: Optional[float] = None + normalize_expert_weights: Optional[float] = None + uniform_expert_assignment: bool = False + bias: bool = True + dtype: DType = DType.float32 + + def num_params(self, d_model: int) -> int: + """ + The number of params that the module will have once built. + + :param d_model: The model dimensionality. + """ + num_params = 0 + if self.name == MoERouterType.default: + num_params += d_model * self.num_experts + if self.bias: + num_params += self.num_experts + else: + raise NotImplementedError + + return num_params + + def build(self, d_model: int, *, init_device: str = "cpu") -> "MoERouter": + """ + Build the corresponding MoE router module. + + :param d_model: The model dimensionality. + :param init_device: The device initialize the parameters on, e.g. "cpu", "meta". + """ + kwargs = self.as_dict(exclude_none=True, recurse=False) + kwargs.pop("name") + kwargs.update( + dtype=kwargs.pop("dtype").as_pt(), + d_model=d_model, + init_device=init_device, + ) + + try: + if self.name == MoERouterType.default: + return MoELinearRouter(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e + + class MoERouter(nn.Module): """ A base class for MoE router modules. From 122af0a61d5dd4ca5506ce20be06fdb19701bf9b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 14:39:00 -0800 Subject: [PATCH 003/230] update Docker build --- Makefile | 5 ++++- src/Dockerfile | 5 ++++- src/test/nn/moe/__init__.py | 0 3 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 src/test/nn/moe/__init__.py diff --git a/Makefile b/Makefile index 395ccc59a..a0bb002f8 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,8 @@ TORCH_VERSION_SHORT = $(shell echo $(TORCH_VERSION) | tr -d .) TORCH_NIGHTLY_VERSION = "2.6.0.dev20241209" TORCH_NIGHTLY_VERSION_SHORT = $(shell echo $(TORCH_NIGHTLY_VERSION) | tr -d .) 
TORCHAO_VERSION = "0.6.1" -MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" +GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" +MEGABLOCKS_VERSION = "megablocks @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl VERSION = $(shell python src/olmo_core/version.py) @@ -55,6 +56,7 @@ stable-image : --build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \ --build-arg TORCH_VERSION=$(TORCH_VERSION) \ --build-arg FLASH_ATTN_WHEEL=$(FLASH_ATTN_WHEEL) \ + --build-arg GROUPED_GEMM_VERSION=$(GROUPED_GEMM_VERSION) \ --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --target stable \ @@ -70,6 +72,7 @@ nightly-image : --build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \ --build-arg TORCH_VERSION=$(TORCH_VERSION) \ --build-arg FLASH_ATTN_WHEEL=$(FLASH_ATTN_WHEEL) \ + --build-arg GROUPED_GEMM_VERSION=$(GROUPED_GEMM_VERSION) \ --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --build-arg TORCH_NIGHTLY_VERSION=$(TORCH_NIGHTLY_VERSION) \ diff --git a/src/Dockerfile b/src/Dockerfile index 186cedbb3..c3ab82b22 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -27,7 +27,10 @@ RUN pip install --upgrade --no-cache-dir pip wheel packaging "setuptools<70.0.0" # Build megablocks, grouped-gemm, stanford-stk ENV TORCH_CUDA_ARCH_LIST="8.0 9.0" ENV GROUPED_GEMM_CUTLASS="1" -ARG MEGABLOCKS_VERSION="megablocks[gg] @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" +ARG GROUPED_GEMM_VERSION="grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" +RUN pip wheel --no-build-isolation --no-cache-dir "${GROUPED_GEMM_VERSION}" + +ARG MEGABLOCKS_VERSION="megablocks @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" # Build flash-attn. 
diff --git a/src/test/nn/moe/__init__.py b/src/test/nn/moe/__init__.py new file mode 100644 index 000000000..e69de29bb From 937baf51d2ceb5b3857648ceaeb9d7449bc1fbcc Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 14:54:28 -0800 Subject: [PATCH 004/230] update MLP test --- src/test/nn/moe/mlp_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 src/test/nn/moe/mlp_test.py diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py new file mode 100644 index 000000000..fd1762af5 --- /dev/null +++ b/src/test/nn/moe/mlp_test.py @@ -0,0 +1,14 @@ +import torch + +from olmo_core.nn.moe.mlp import MoEMLP + +from ...utils import requires_gpu + + +@requires_gpu +def test_mlp(): + mlp = MoEMLP(d_model=128, hidden_size=256, num_experts=2, init_device="cuda") + x = torch.randn(5, 128, device="cuda") + tokens_per_expert = torch.tensor([3, 2], device="cuda") + out = mlp(x, tokens_per_expert) + assert out.shape == (5, 128) From 0b127d248cafdd9bc9603461e20279538823006b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 15:01:23 -0800 Subject: [PATCH 005/230] add a test with expert parallel --- src/test/nn/moe/mlp_test.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index fd1762af5..ea01dbcc5 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -1,7 +1,11 @@ import torch +import torch.distributed as dist +from torch.distributed.tensor import init_device_mesh from olmo_core.nn.moe.mlp import MoEMLP +from olmo_core.utils import get_default_device +from ...distributed.utils import requires_multi_gpu, run_distributed_test from ...utils import requires_gpu @@ -12,3 +16,19 @@ def test_mlp(): tokens_per_expert = torch.tensor([3, 2], device="cuda") out = mlp(x, tokens_per_expert) assert out.shape == (5, 128) + + +def run_mlp_with_expert_parallelism(): + mlp = MoEMLP(d_model=128, hidden_size=256, num_experts=4, init_device="meta") + ep_mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) + mlp.apply_ep(ep_mesh) + mlp.to_empty(device=get_default_device()) + x = torch.randn(5, 128, device="cuda") + tokens_per_expert = torch.tensor([3, 2], device="cuda") + out = mlp(x, tokens_per_expert) + assert out.shape == (5, 128) + + +@requires_multi_gpu +def test_mlp_with_expert_parallelism(): + run_distributed_test(run_mlp_with_expert_parallelism, backend="nccl", start_method="spawn") From 9746d9a236f54aabf64bfac60b3cc4f226de1646 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 16:16:16 -0800 Subject: [PATCH 006/230] clean up test --- src/test/nn/moe/mlp_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index ea01dbcc5..aa95d3bdd 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -19,13 +19,18 @@ def test_mlp(): def run_mlp_with_expert_parallelism(): - mlp = MoEMLP(d_model=128, hidden_size=256, num_experts=4, init_device="meta") ep_mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) + + mlp = MoEMLP( + d_model=128, hidden_size=256, num_experts=dist.get_world_size() * 2, init_device="meta" + ) mlp.apply_ep(ep_mesh) mlp.to_empty(device=get_default_device()) + x = torch.randn(5, 128, device="cuda") tokens_per_expert = torch.tensor([3, 2], device="cuda") out = mlp(x, tokens_per_expert) + assert out.shape == (5, 128) From 1dba833c50cdb02bec2170217d63e81840759be7 Mon Sep 17 00:00:00 2001 
From: epwalsh Date: Thu, 30 Jan 2025 16:23:11 -0800 Subject: [PATCH 007/230] add launch script to quickly run tests --- src/scripts/beaker/launch.py | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/scripts/beaker/launch.py diff --git a/src/scripts/beaker/launch.py b/src/scripts/beaker/launch.py new file mode 100644 index 000000000..639d11b30 --- /dev/null +++ b/src/scripts/beaker/launch.py @@ -0,0 +1,38 @@ +""" +Launch a command on Beaker. +""" + +import sys +from typing import List + +from olmo_core.launch.beaker import BeakerLaunchConfig, OLMoCoreBeakerImage +from olmo_core.utils import generate_uuid, prepare_cli_environment + + +def build_config(cmd: List[str]) -> BeakerLaunchConfig: + return BeakerLaunchConfig( + name=f"olmo-core-test-{generate_uuid()[:8]}", + budget="ai2/oe-training", + cmd=cmd, + task_name="test", + workspace="ai2/OLMo-core", + beaker_image=OLMoCoreBeakerImage.stable, + clusters=[ + "ai2/jupiter-cirrascale-2", + "ai2/augusta-google-1", + "ai2/ceres-cirrascale", + ], + num_nodes=1, + num_gpus=2, + shared_filesystem=True, + ) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} CMD ...") + sys.exit(1) + + prepare_cli_environment() + + build_config(sys.argv[1:]).launch(follow=True, torchrun=False) From 90921dbea79da14e32f0026394c69ee244c5255d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 16:30:20 -0800 Subject: [PATCH 008/230] fix dtype --- src/test/nn/moe/mlp_test.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index aa95d3bdd..225eb4739 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -11,8 +11,10 @@ @requires_gpu def test_mlp(): - mlp = MoEMLP(d_model=128, hidden_size=256, num_experts=2, init_device="cuda") - x = torch.randn(5, 128, device="cuda") + mlp = MoEMLP( + d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 + ) + x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) tokens_per_expert = torch.tensor([3, 2], device="cuda") out = mlp(x, tokens_per_expert) assert out.shape == (5, 128) @@ -22,12 +24,16 @@ def run_mlp_with_expert_parallelism(): ep_mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) mlp = MoEMLP( - d_model=128, hidden_size=256, num_experts=dist.get_world_size() * 2, init_device="meta" + d_model=128, + hidden_size=256, + num_experts=dist.get_world_size() * 2, + init_device="meta", + dtype=torch.bfloat16, ) mlp.apply_ep(ep_mesh) mlp.to_empty(device=get_default_device()) - x = torch.randn(5, 128, device="cuda") + x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) tokens_per_expert = torch.tensor([3, 2], device="cuda") out = mlp(x, tokens_per_expert) From 027bc5bb1081f87736191a58543694adb1f501bb Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 16:36:42 -0800 Subject: [PATCH 009/230] try no host networking --- src/olmo_core/launch/beaker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 7d234d315..e29a56bec 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -350,8 +350,9 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: command=["bash", "/olmo-core/entrypoint.sh"], replicas=self.num_nodes if self.num_nodes > 1 else None, leader_selection=self.num_nodes > 1, - host_networking=self.num_nodes 
> 1 - or any(["augusta" in cluster for cluster in self.clusters]), + # host_networking=self.num_nodes > 1 + # or any(["augusta" in cluster for cluster in self.clusters]), + host_networking=self.num_nodes > 1, propagate_failure=False if self.num_nodes > 1 else None, propagate_preemption=True if self.num_nodes > 1 else None, synchronized_start_timeout="90m" if self.num_nodes > 1 else None, From 68a6b0f42d86cde3350bd54f642d4c8e0af7420f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 16:40:33 -0800 Subject: [PATCH 010/230] make host-networking configurable --- src/olmo_core/launch/beaker.py | 10 +++++++--- src/scripts/beaker/launch.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index e29a56bec..5d855acfa 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -206,6 +206,8 @@ class BeakerLaunchConfig(Config): Allow running with uncommitted changed. """ + host_networking: Optional[bool] = None + # NOTE: don't assign a type here because omegaconf can't validate arbitrary classes # _beaker: Optional[Beaker] = None _beaker = None @@ -350,9 +352,11 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: command=["bash", "/olmo-core/entrypoint.sh"], replicas=self.num_nodes if self.num_nodes > 1 else None, leader_selection=self.num_nodes > 1, - # host_networking=self.num_nodes > 1 - # or any(["augusta" in cluster for cluster in self.clusters]), - host_networking=self.num_nodes > 1, + host_networking=self.host_networking + if self.host_networking is not None + else ( + self.num_nodes > 1 or any(["augusta" in cluster for cluster in self.clusters]) + ), propagate_failure=False if self.num_nodes > 1 else None, propagate_preemption=True if self.num_nodes > 1 else None, synchronized_start_timeout="90m" if self.num_nodes > 1 else None, diff --git a/src/scripts/beaker/launch.py b/src/scripts/beaker/launch.py index 639d11b30..0798a7faa 100644 --- a/src/scripts/beaker/launch.py +++ b/src/scripts/beaker/launch.py @@ -25,6 +25,7 @@ def build_config(cmd: List[str]) -> BeakerLaunchConfig: num_nodes=1, num_gpus=2, shared_filesystem=True, + host_networking=False, ) From 149e47f31da4a5a3c7c1a11f2aaea966ddfaba44 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 16:49:45 -0800 Subject: [PATCH 011/230] add config class --- src/olmo_core/nn/moe/mlp.py | 64 ++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 74ccc2a65..961cd12be 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -1,4 +1,5 @@ import warnings +from dataclasses import dataclass from typing import Any, Callable, Optional import torch @@ -7,10 +8,11 @@ from torch.distributed import DeviceMesh from torch.distributed.tensor import Shard, distribute_tensor +from ...config import Config, DType, StrEnum from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError -__all__ = ["MoEMLP"] +__all__ = ["MoEMLP", "MoEMLPConfig", "MoEMLPType"] class _ScaleGradient(torch.autograd.Function): @@ -29,7 +31,67 @@ def backward(ctx: torch.Tensor, grad: torch.Tensor): _scale_gradient: Callable[[torch.Tensor, float], torch.Tensor] = _ScaleGradient.apply # type: ignore +class MoEMLPType(StrEnum): + """ + An enumeration of the different MoE expert MLP implementations. 
+ """ + + default = "default" + """ + ➡️ :class:`MoEMLP` + """ + + +@dataclass +class MoEMLPConfig(Config): + name: MoEMLPType = MoEMLPType.default + """ + The name of the implementation. + """ + + hidden_size: int = 1024 + num_experts: int = 1 + dtype: DType = DType.float32 + + def num_params(self, d_model: int) -> int: + """ + The number of params that the module will have once built. + + :param d_model: The model dimensionality. + """ + num_params = 0 + if self.name == MoEMLPType.default: + num_params += 3 * d_model * self.hidden_size * self.num_experts + else: + raise NotImplementedError + + return num_params + + def build(self, d_model: int, *, init_device: str = "cpu") -> "MoEMLP": + kwargs = self.as_dict(exclude_none=True, recurse=False) + kwargs.pop("name") + kwargs.update( + dtype=kwargs.pop("dtype").as_pt(), + d_model=d_model, + init_device=init_device, + ) + + try: + if self.name == MoEMLPType.default: + return MoEMLP(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e + + class MoEMLP(nn.Module): + """ + A basic expert MLP module with SwiGLU activation. + """ + def __init__( self, *, From 1dc5c12a0f6cf77c51f85ae62998a9bf69307a30 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 19:56:48 -0800 Subject: [PATCH 012/230] clean up for running tests --- src/olmo_core/launch/beaker.py | 15 +++++++++++---- src/scripts/beaker/{launch.py => launch_test.py} | 8 ++++---- 2 files changed, 15 insertions(+), 8 deletions(-) rename src/scripts/beaker/{launch.py => launch_test.py} (83%) diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 5d855acfa..0eaae09cc 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -302,7 +302,9 @@ def _create_script_dataset(self, script_name: str, script: List[str]) -> Dataset return dataset - def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: + def build_experiment_spec( + self, torchrun: bool = True, entrypoint: Optional[str] = None + ) -> ExperimentSpec: """ Get the Beaker experiment spec corresponding to this config instance. """ @@ -338,7 +340,8 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: ) entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"') else: - entrypoint_script.append('python "$@"') + entrypoint = entrypoint or "python" + entrypoint_script.append(f'{entrypoint} "$@"') entrypoint_dataset = self._create_script_dataset("entrypoint.sh", entrypoint_script) @@ -434,7 +437,9 @@ def _follow_experiment(self, experiment: Experiment): else: log.info("Experiment completed successfully") - def launch(self, follow: bool = False, torchrun: bool = True) -> Experiment: + def launch( + self, follow: bool = False, torchrun: bool = True, entrypoint: Optional[str] = None + ) -> Experiment: """ Launch a Beaker experiment using this config. @@ -444,10 +449,12 @@ def launch(self, follow: bool = False, torchrun: bool = True) -> Experiment: :param follow: Stream the logs and follow the experiment until completion. :param torchrun: Launch the target command with ``torchrun``. + :param entrypoint: Provide an optional entrypoint program if ``torchrun`` is ``False``. + Defaults to 'python'. :returns: The Beaker experiment. 
""" - spec = self.build_experiment_spec(torchrun=torchrun) + spec = self.build_experiment_spec(torchrun=torchrun, entrypoint=entrypoint) experiment = self.beaker.experiment.create(self.name, spec) log.info(f"Experiment submitted, see progress at {self.beaker.experiment.url(experiment)}") diff --git a/src/scripts/beaker/launch.py b/src/scripts/beaker/launch_test.py similarity index 83% rename from src/scripts/beaker/launch.py rename to src/scripts/beaker/launch_test.py index 0798a7faa..ff49468ef 100644 --- a/src/scripts/beaker/launch.py +++ b/src/scripts/beaker/launch_test.py @@ -1,5 +1,5 @@ """ -Launch a command on Beaker. +Launch tests on Beaker. """ import sys @@ -11,7 +11,7 @@ def build_config(cmd: List[str]) -> BeakerLaunchConfig: return BeakerLaunchConfig( - name=f"olmo-core-test-{generate_uuid()[:8]}", + name=f"olmo-core-pytest-{generate_uuid()[:8]}", budget="ai2/oe-training", cmd=cmd, task_name="test", @@ -31,9 +31,9 @@ def build_config(cmd: List[str]) -> BeakerLaunchConfig: if __name__ == "__main__": if len(sys.argv) < 2: - print(f"Usage: python {sys.argv[0]} CMD ...") + print(f"Usage: python {sys.argv[0]} [PYTEST_OPTS...]") sys.exit(1) prepare_cli_environment() - build_config(sys.argv[1:]).launch(follow=True, torchrun=False) + build_config(sys.argv[1:]).launch(follow=True, torchrun=False, entrypoint="pytest") From 72939aadda4fb8281619743e88fb973fc3b170e9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 20:03:43 -0800 Subject: [PATCH 013/230] improve script --- src/scripts/beaker/launch_test.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/scripts/beaker/launch_test.py b/src/scripts/beaker/launch_test.py index ff49468ef..0b174412f 100644 --- a/src/scripts/beaker/launch_test.py +++ b/src/scripts/beaker/launch_test.py @@ -5,15 +5,17 @@ import sys from typing import List +from rich import print + from olmo_core.launch.beaker import BeakerLaunchConfig, OLMoCoreBeakerImage from olmo_core.utils import generate_uuid, prepare_cli_environment -def build_config(cmd: List[str]) -> BeakerLaunchConfig: +def build_config(pytest_opts: List[str], overrides: List[str]) -> BeakerLaunchConfig: return BeakerLaunchConfig( name=f"olmo-core-pytest-{generate_uuid()[:8]}", budget="ai2/oe-training", - cmd=cmd, + cmd=pytest_opts, task_name="test", workspace="ai2/OLMo-core", beaker_image=OLMoCoreBeakerImage.stable, @@ -26,14 +28,20 @@ def build_config(cmd: List[str]) -> BeakerLaunchConfig: num_gpus=2, shared_filesystem=True, host_networking=False, - ) + ).merge(overrides) if __name__ == "__main__": - if len(sys.argv) < 2: - print(f"Usage: python {sys.argv[0]} [PYTEST_OPTS...]") + if len(sys.argv) < 3 or "--" not in sys.argv: + print(f"Usage: python {sys.argv[0]} [OVERRIDES...] -- [PYTEST_OPTS...] 
TEST_TARGET") sys.exit(1) + sep_index = sys.argv.index("--") + overrides = sys.argv[1:sep_index] + pytest_opts = sys.argv[sep_index + 1 :] + prepare_cli_environment() - build_config(sys.argv[1:]).launch(follow=True, torchrun=False, entrypoint="pytest") + config = build_config(pytest_opts, overrides) + print(config) + config.launch(follow=True, torchrun=False, entrypoint="pytest") From c7a3890225c32b8a931ceb4ca3550560d7b71a83 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 20:08:29 -0800 Subject: [PATCH 014/230] setup distributed --- src/olmo_core/distributed/utils.py | 4 ++-- src/test/distributed/utils.py | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index 22801e567..fc3f2e101 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -27,7 +27,7 @@ log = logging.getLogger(__name__) -def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minutes=30)): +def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minutes=30), **kwargs): """ Initialize the distributed process group with the given backend(s) and check/set the relevant environment variables. @@ -102,7 +102,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut device = torch.device(f"cuda:{int(os.environ[OLMO_LOCAL_RANK_ENV_VAR])}") torch.cuda.set_device(device) - dist.init_process_group(backend, timeout=timeout) + dist.init_process_group(backend, timeout=timeout, **kwargs) validate_env_vars() diff --git a/src/test/distributed/utils.py b/src/test/distributed/utils.py index 01b7fa2dc..40782ed64 100644 --- a/src/test/distributed/utils.py +++ b/src/test/distributed/utils.py @@ -12,6 +12,7 @@ from olmo_core.distributed.utils import ( OLMO_LOCAL_WORLD_SIZE_ENV_VAR, OLMO_NUM_NODES_ENV_VAR, + init_distributed, is_distributed, ) @@ -112,17 +113,24 @@ def log_record_factory(*args, **kwargs) -> logging.LogRecord: log = logging.getLogger() - dist.init_process_group( + os.environ.setdefault(OLMO_NUM_NODES_ENV_VAR, "1") + os.environ.setdefault(OLMO_LOCAL_WORLD_SIZE_ENV_VAR, str(world_size)) + + # dist.init_process_group( + # backend=backend, + # init_method=f"tcp://{primary_addr}:{primary_port}", + # world_size=world_size, + # rank=process_rank, + # timeout=datetime.timedelta(seconds=120), + # ) + init_distributed( backend=backend, + timeout=datetime.timedelta(seconds=120), init_method=f"tcp://{primary_addr}:{primary_port}", world_size=world_size, rank=process_rank, - timeout=datetime.timedelta(seconds=120), ) - os.environ.setdefault(OLMO_NUM_NODES_ENV_VAR, "1") - os.environ.setdefault(OLMO_LOCAL_WORLD_SIZE_ENV_VAR, str(world_size)) - log.info("Starting test...") if "nccl" in backend: From 3e80cc4c58d1bd24e6869d0ff8c8dac2e8736eee Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 20:10:26 -0800 Subject: [PATCH 015/230] fix --- src/scripts/beaker/launch_test.py | 2 +- src/test/distributed/utils.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/scripts/beaker/launch_test.py b/src/scripts/beaker/launch_test.py index 0b174412f..ba980b400 100644 --- a/src/scripts/beaker/launch_test.py +++ b/src/scripts/beaker/launch_test.py @@ -27,7 +27,7 @@ def build_config(pytest_opts: List[str], overrides: List[str]) -> BeakerLaunchCo num_nodes=1, num_gpus=2, shared_filesystem=True, - host_networking=False, + # host_networking=False, ).merge(overrides) diff --git a/src/test/distributed/utils.py 
b/src/test/distributed/utils.py index 40782ed64..2aaa22527 100644 --- a/src/test/distributed/utils.py +++ b/src/test/distributed/utils.py @@ -10,6 +10,7 @@ import torch.multiprocessing as mp from olmo_core.distributed.utils import ( + OLMO_LOCAL_RANK_ENV_VAR, OLMO_LOCAL_WORLD_SIZE_ENV_VAR, OLMO_NUM_NODES_ENV_VAR, init_distributed, @@ -115,6 +116,7 @@ def log_record_factory(*args, **kwargs) -> logging.LogRecord: os.environ.setdefault(OLMO_NUM_NODES_ENV_VAR, "1") os.environ.setdefault(OLMO_LOCAL_WORLD_SIZE_ENV_VAR, str(world_size)) + os.environ.setdefault(OLMO_LOCAL_RANK_ENV_VAR, str(process_rank)) # dist.init_process_group( # backend=backend, From 1f208faf2eedfe36d855176a33f7246680a275cf Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Jan 2025 20:12:06 -0800 Subject: [PATCH 016/230] fix --- src/test/distributed/utils.py | 43 +++++++++++++++++------------------ 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/src/test/distributed/utils.py b/src/test/distributed/utils.py index 2aaa22527..426161f7a 100644 --- a/src/test/distributed/utils.py +++ b/src/test/distributed/utils.py @@ -94,6 +94,27 @@ def init_process( ): assert world_size > 1 + os.environ.setdefault(OLMO_NUM_NODES_ENV_VAR, "1") + os.environ.setdefault(OLMO_LOCAL_WORLD_SIZE_ENV_VAR, str(world_size)) + os.environ.setdefault(OLMO_LOCAL_RANK_ENV_VAR, str(process_rank)) + + # dist.init_process_group( + # backend=backend, + # init_method=f"tcp://{primary_addr}:{primary_port}", + # world_size=world_size, + # rank=process_rank, + # timeout=datetime.timedelta(seconds=120), + # ) + # if "nccl" in backend: + # torch.cuda.set_device(int(process_rank)) + init_distributed( + backend=backend, + timeout=datetime.timedelta(seconds=120), + init_method=f"tcp://{primary_addr}:{primary_port}", + world_size=world_size, + rank=process_rank, + ) + old_log_record_factory = logging.getLogRecordFactory() def log_record_factory(*args, **kwargs) -> logging.LogRecord: @@ -114,30 +135,8 @@ def log_record_factory(*args, **kwargs) -> logging.LogRecord: log = logging.getLogger() - os.environ.setdefault(OLMO_NUM_NODES_ENV_VAR, "1") - os.environ.setdefault(OLMO_LOCAL_WORLD_SIZE_ENV_VAR, str(world_size)) - os.environ.setdefault(OLMO_LOCAL_RANK_ENV_VAR, str(process_rank)) - - # dist.init_process_group( - # backend=backend, - # init_method=f"tcp://{primary_addr}:{primary_port}", - # world_size=world_size, - # rank=process_rank, - # timeout=datetime.timedelta(seconds=120), - # ) - init_distributed( - backend=backend, - timeout=datetime.timedelta(seconds=120), - init_method=f"tcp://{primary_addr}:{primary_port}", - world_size=world_size, - rank=process_rank, - ) - log.info("Starting test...") - if "nccl" in backend: - torch.cuda.set_device(int(process_rank)) - try: func(*(func_args or []), **(func_kwargs or {})) finally: From da045eaf5afd1c5d55854fabb4e6f8a5a189979f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 31 Jan 2025 11:19:44 -0800 Subject: [PATCH 017/230] Add parallel MLP implementation --- src/olmo_core/nn/moe/kernels.py | 300 +++++++++++++++++++++++ src/olmo_core/nn/moe/mlp.py | 15 +- src/olmo_core/nn/moe/ops.py | 212 +++++++++++++++++ src/olmo_core/nn/moe/parallel_mlp.py | 343 +++++++++++++++++++++++++++ src/olmo_core/nn/moe/router.py | 11 +- 5 files changed, 868 insertions(+), 13 deletions(-) create mode 100644 src/olmo_core/nn/moe/kernels.py create mode 100644 src/olmo_core/nn/moe/ops.py create mode 100644 src/olmo_core/nn/moe/parallel_mlp.py diff --git a/src/olmo_core/nn/moe/kernels.py b/src/olmo_core/nn/moe/kernels.py new file mode 
100644 index 000000000..7c5ef9ab9 --- /dev/null +++ b/src/olmo_core/nn/moe/kernels.py @@ -0,0 +1,300 @@ +# Adapted from https://github.com/databricks/megablocks/blob/main/megablocks/backend/kernels.py + +from typing import Optional + +import torch +import triton # type: ignore +import triton.language as tl # type: ignore + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor") + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor") + + +def assert_equal(a, b): + if a != b: + raise ValueError( + f"Expected dimensions to be equal but got {a} and {b}.", + ) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), + ], + key=["NUM_COLUMNS"], +) +@triton.jit +def _padded_copy( + a, # (tokens, hidden_size), real. + b, + indices, # (tokens * top_k), integer. + bin_ids, # (tokens * top_k), integer. + weights, # (tokens * top_k), real. + bins, # (num_experts), integer. + padded_bins, # (num_experts), integer. + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) # type: ignore + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + # Validate the input shapes. 
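+    # 'indices' and 'bin_ids' carry one entry per (token, expert) assignment, so
+    # their length must equal x.shape[0] * top_k; 'bins' holds one cumulative end
+    # offset per expert.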
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( # type: ignore + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( # type: ignore + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), + ], + key=["NUM_COLUMNS"], +) +@triton.jit +def _padded_copy_wgrad( + x, # (tokens, top_k, hidden_size), real. + grad, # (tokens, hidden_size), real. + wgrad, # (tokens, top_k), real. + indices, # (tokens * top_k), integer. + bin_ids, # (tokens * top_k), integer. + bins, # (num_experts), integer. + padded_bins, # (num_experts), integer. + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
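+    # 'wgrad' is indexed per (token, expert) assignment, 'grad' by the source
+    # token (index_out // TOP_K), and 'x' by the permuted row for this assignment
+    # (index_x).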
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad( + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( # type: ignore + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad( + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 961cd12be..80dc3a9a7 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -108,6 +108,7 @@ def __init__( self.gradient_scale: Optional[float] = None self.experts_per_rank = num_experts + self.hidden_sharding_degree = 1 self.w1 = nn.Parameter( torch.empty( @@ -169,13 +170,13 @@ def gmm( start += size return torch.cat(out) - def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, batch_size_per_expert: torch.Tensor) -> torch.Tensor: """ Compute the expert outputs. - :param x: The input of shape ``(total_tokens, d_model)``. - :param tokens_per_expert: Specifies how many tokens go to each expert. Should be a - 1-D ``LongTensor``. + :param x: The input of shape ``(N, d_model)``. + :param batch_size_per_expert: Specifies how many items/tokens go to each expert. Should be a + 1-D ``LongTensor`` which sums to ``N``. """ # Scale gradients and get local tensors (in case of expert parallelism). # shape (all): (experts_per_rank, hidden_size, d_model) @@ -186,10 +187,10 @@ def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Ten ) # Compute the MLP. 
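+        # SwiGLU: silu(x @ w1^T) * (x @ w3^T) per expert, projected back to
+        # d_model with w2.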
- x1 = self.gmm(x, w1, tokens_per_expert, trans_b=True) - x2 = self.gmm(x, w3, tokens_per_expert, trans_b=True) + x1 = self.gmm(x, w1, batch_size_per_expert, trans_b=True) + x2 = self.gmm(x, w3, batch_size_per_expert, trans_b=True) x1 = F.silu(x1) * x2 - return self.gmm(x1, w2, tokens_per_expert) + return self.gmm(x1, w2, batch_size_per_expert) def apply_ep(self, ep_mesh: DeviceMesh): """ diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py new file mode 100644 index 000000000..5922f5790 --- /dev/null +++ b/src/olmo_core/nn/moe/ops.py @@ -0,0 +1,212 @@ +import functools +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +from . import kernels + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, dict): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, (list, tuple)): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def autocast_fwd(fwd): + """ + Wrap a custom autograd forward function to ensure it always uses the autocast dtype. + """ + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + + return decorate_fwd + + +def autocast_bwd(bwd): + """ + Wrap a custom autograd backward function to ensure it always uses the autocast dtype. + """ + + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + + return decorate_bwd + + +class GatherOp(torch.autograd.Function): + @staticmethod + @autocast_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @autocast_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +def gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + return GatherOp.apply(x, indices, bin_ids, bins, top_k) # type: ignore + + +class ScatterOp(torch.autograd.Function): + @staticmethod + @autocast_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @autocast_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if 
ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) # type: ignore + + +def repeat(x: torch.Tensor, tiling: Union[torch.Size, Tuple[int, ...]]) -> torch.Tensor: + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) + + +class AllToAllOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all( + x: torch.Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False, +) -> Tuple[torch.Tensor, Any]: + return AllToAllOp.apply( # type: ignore + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) + + +def sum(x: torch.Tensor, dim: int = 0) -> torch.Tensor: + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py new file mode 100644 index 000000000..d8eb4b2ca --- /dev/null +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -0,0 +1,343 @@ +# Adapted from 'https://github.com/databricks/megablocks/blob/main/megablocks/layers/moe.py' and 'dmoe.py' + +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import DeviceMesh + +from ...distributed.utils import get_world_size +from . import ops +from .mlp import MoEMLP + + +class ParallelDroplessMLP(nn.Module): + """ + Wraps an MoE MLP layer to coordinate the routing and expert parallelism. + """ + + def __init__(self, *, mlp: MoEMLP): + super().__init__() + self.mlp = mlp + self._expert_parallel_enabled: bool = False + self._ep_mesh: Optional[DeviceMesh] = None + self._ep_pg: Optional[dist.ProcessGroup] = None + + @property + def d_model(self) -> int: + return self.mlp.d_model + + @property + def num_experts(self) -> int: + return self.mlp.num_experts + + @property + def experts_per_rank(self) -> int: + return self.mlp.experts_per_rank + + @property + def hidden_sharding_degree(self) -> int: + return self.mlp.hidden_sharding_degree + + @property + def ep_world_size(self) -> int: + if self._ep_pg is not None: + return get_world_size(self._ep_pg) + else: + return 1 + + def apply_ep(self, ep_mesh: DeviceMesh): + """ + Apply expert parallelism. 
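+
+        :param ep_mesh: A 1-D device mesh to shard the experts over. Its size must
+            evenly divide ``num_experts``.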
+ """ + self.mlp.apply_ep(ep_mesh) + self._expert_parallel_enabled = True + self._ep_mesh = ep_mesh + self._ep_pg = ep_mesh.get_group() + + def forward( + self, + x: torch.Tensor, + scores: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> torch.Tensor: + """ + :param x: The input of shape ``(*, d_model)``. + :param scores: The expert scores of shape ``(N, num_experts)``, where ``N`` + typically equals ``batch_size x seq_len``. + :param expert_weights: Expert weights of shape ``(N, top_k)``. + :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. + """ + in_shape = x.size() + + # Compute the experts. + if self._expert_parallel_enabled: + x, batch_size_per_expert = self.parallel_forward_once(x, expert_weights, expert_indices) + else: + x, batch_size_per_expert = self.forward_once(x, expert_weights, expert_indices) + + del scores, batch_size_per_expert + # TODO: save load balancing loss + # if self.training and self.args.moe_loss_weight > 0: + # save_load_balancing_loss((tokens_per_expert, scores)) + + return x.view(in_shape) + + def forward_once( + self, + x: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param x: The input of shape ``(*, d_model)``. + :param expert_weights: Expert weights of shape ``(N, top_k)``, where ``N`` + typically equals ``batch_size x seq_len``. + :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. + """ + top_k = expert_weights.shape[-1] + + # shape: (N * top_k,) + expert_weights = expert_weights.flatten() + # shape: (N * top_k,) + expert_indices = expert_indices.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) + + out = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + top_k, + ) + + return out, tokens_per_expert + + def parallel_forward_once( + self, + x: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param x: The input of shape ``(*, d_model)``. + :param expert_weights: Expert weights of shape ``(N, top_k)``, where ``N`` + typically equals ``batch_size x seq_len``. + :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. + """ + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. 
+ + top_k = expert_weights.shape[-1] + + # shape: (N * top_k,) + expert_weights = expert_weights.flatten() + # shape: (N * top_k,) + expert_indices = expert_indices.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (self.hidden_sharding_degree,), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like( + repeated_tokens_per_expert, + ) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self._ep_pg, + async_op=True, + ) + assert tpe_handle is not None + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + x = ops.gather(x.view(-1, x.shape[-1]), indices, bin_ids, bins, top_k) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + + # Reshape to (ep_world_size, experts_per_rank). + repeated_tokens_per_expert = repeated_tokens_per_expert.view( + self.ep_world_size, self.experts_per_rank + ) + parallel_tokens_per_expert = parallel_tokens_per_expert.view( + self.ep_world_size, self.experts_per_rank + ) + + # TODO: can we avoid the host-device sync? + send_counts = repeated_tokens_per_expert.sum(dim=-1).cpu().tolist() + recv_counts = parallel_tokens_per_expert.sum(dim=-1).cpu().tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # TODO: Fuse this into the prior, local permutation? + x = ops.repeat(x, (self.hidden_sharding_degree, 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = ops.all_to_all( + x, + recv_counts, + send_counts, + group=self._ep_pg, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * self.hidden_sharding_degree, + dtype=torch.int32, + device=indices.device, + ), + self.experts_per_rank, + ) + + parallel_top_expert = torch.repeat_interleave( + parallel_top_expert, + parallel_tokens_per_expert.flatten(), + output_size=tokens_received, + ) + # replicate_bins = torch.cumsum(parallel_tokens_per_expert.flatten(), 0) + # parallel_top_expert = ops.replicate( + # parallel_top_expert.unsqueeze(dim=0), + # replicate_bins, + # tokens_received, + # ).flatten() + + parallel_bin_ids, parallel_indices = torch.sort(parallel_top_expert) + + # Calculate the bins boundaries from the token counts. 
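A toy, single-process illustration of the ``torch.repeat_interleave`` call above, assuming two expert-parallel ranks with two local experts each and no hidden sharding:

    import torch

    # Rows: sending rank; columns: local expert on this rank.
    parallel_tokens_per_expert = torch.tensor([[3, 1],
                                               [0, 2]])
    local_expert_ids = torch.remainder(torch.arange(4), 2)   # [0, 1, 0, 1]
    parallel_top_expert = torch.repeat_interleave(
        local_expert_ids, parallel_tokens_per_expert.flatten()
    )
    # tensor([0, 0, 0, 1, 1, 1]): the local expert id of every received token,
    # which a subsequent sort groups into per-expert contiguous chunks.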
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = torch.cumsum(parallel_tokens_per_expert, 0) + parallel_bins = ( + parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins + ) + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + 1, + ) + + # Un-permute the tokens across the devices. + x, _ = ops.all_to_all( + parallel_x, + send_counts, + recv_counts, + group=self._ep_pg, + ) + + # Reduce along the hidden sharding to get the final outputs. + # TODO: Fuse this into the following local permutation? + x = ops.sum(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + return x, tokens_per_expert.flatten() + + def indices_and_bins( + self, expert_indices: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + :param expert_indices: A 1D tensor. + """ + # shape: (N,) + expert_indices = expert_indices.int() + + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # shape: (N,), (N,) + bin_ids, indices = torch.sort(expert_indices) + + # Histogram the expert ids to identify the number of + # items/tokens routed to each expert. + # shape: (num_experts,) + batch_size_per_expert = torch.histc( + expert_indices, bins=self.num_experts, min=0, max=self.num_experts - 1 + ) + + # Calculate the bin bounds for the sorted items/tokens. + # shape: (num_experts,) + bins = torch.cumsum(batch_size_per_expert, 0) + + return indices, bin_ids, bins, batch_size_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + expert_weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index a79336163..135ea6e50 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -12,7 +12,7 @@ # NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly +# support a flag to force the router to assign items/tokens uniformly # across the experts. We do this with a custom autograd operation # so that PyTorch still executes the full set of router operation. class _UniformExpertAssignment(torch.autograd.Function): @@ -106,7 +106,7 @@ class MoERouter(nn.Module): :param d_model: The model dimensionality (hidden size). :param num_experts: The total number of experts. - :param top_k: The number of experts to assign to each token. + :param top_k: The number of experts to assign to each item/token. :param jitter_eps: Controls the amount of noise added to the input during training. :param normalize_expert_weights: The type of norm (e.g. 
``2.0`` for L2 norm) to use to normalize the expert weights. @@ -156,12 +156,11 @@ def get_expert_scores(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Given the input ``x`` of shape ``(batch_size, seq_len, d_model)``, compute the + Given the input ``x`` of shape ``(*, d_model)``, compute the experts assignment. - :returns: The scores of shape ``(batch_size, seq_len, num_experts)``, the expert weights - of shape ``(batch_size, seq_len, top_k)``, and the expert indices of shape - ``(batch_size, seq_len, top_k)``. + :returns: The scores of shape ``(N, num_experts)``, the expert weights + of shape ``(N, top_k)``, and the expert indices of shape ``(N, top_k)``. """ # shape: (batch_size, seq_len, d_model) x = self.jitter(x) From 0d792d7cc08f50002122e7e67c7fe27b9bc96d19 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 31 Jan 2025 13:29:14 -0800 Subject: [PATCH 018/230] add MoE base --- src/olmo_core/nn/moe/moe.py | 145 +++++++++++++++++++++++++++ src/olmo_core/nn/moe/parallel_mlp.py | 31 ++++-- src/olmo_core/nn/moe/router.py | 20 ++-- src/olmo_core/nn/moe/shared_mlp.py | 106 ++++++++++++++++++++ 4 files changed, 280 insertions(+), 22 deletions(-) create mode 100644 src/olmo_core/nn/moe/moe.py create mode 100644 src/olmo_core/nn/moe/shared_mlp.py diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py new file mode 100644 index 000000000..86b17b43e --- /dev/null +++ b/src/olmo_core/nn/moe/moe.py @@ -0,0 +1,145 @@ +from abc import abstractmethod +from dataclasses import dataclass, field +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh + +from ...config import Config, StrEnum +from ...exceptions import OLMoConfigurationError +from .mlp import MoEMLP, MoEMLPConfig +from .parallel_mlp import ParallelDroplessMLP, ParallelMLP +from .router import MoERouterConfig +from .shared_mlp import SharedMLPConfig + +__all__ = ["MoEBase", "DroplessMoE", "MoEConfig", "MoEType"] + + +class MoEType(StrEnum): + """ + An enumeration of the different MoE implementations. + """ + + dropless = "dropless" + """ + ➡️ :class:`DroplessMoE` + """ + + +@dataclass +class MoEConfig(Config): + name: MoEType = MoEType.dropless + """ + The name of the implementation. + """ + router: MoERouterConfig = field(default_factory=MoERouterConfig) + mlp: MoEMLPConfig = field(default_factory=MoEMLPConfig) + shared_mlp: Optional[SharedMLPConfig] = None + lb_loss_weight: Optional[float] = None + z_loss_weight: Optional[float] = None + + def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> "MoEBase": + kwargs = self.as_dict(exclude_none=True, recurse=False) + kwargs.pop("name") + kwargs.update( + dtype=kwargs.pop("dtype").as_pt(), + d_model=d_model, + num_layers=num_layers, + init_device=init_device, + ) + + try: + if self.name == MoEType.dropless: + return DroplessMoE(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e + + +class MoEBase(nn.Module): + """ + Base class for MoE implementations. 
+ """ + + def __init__( + self, + *, + d_model: int, + router: MoERouterConfig, + mlp: MoEMLPConfig, + num_layers: int, + shared_mlp: Optional[SharedMLPConfig] = None, + init_device: str = "cpu", + lb_loss_weight: Optional[float] = None, + z_loss_weight: Optional[float] = None, + ): + super().__init__() + self.router = router.build(d_model, init_device=init_device) + self.experts = self._init_parallel_mlp(mlp.build(d_model, init_device=init_device)) + self.shared_experts = ( + None if shared_mlp is None else shared_mlp.build(d_model, init_device=init_device) + ) + self.num_layers = num_layers + self.lb_loss_weight = lb_loss_weight + self.z_loss_weight = z_loss_weight + + @abstractmethod + @classmethod + def _init_parallel_mlp(cls, mlp: MoEMLP) -> ParallelMLP: + raise NotImplementedError + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Run the MoE on the input ``x`` of shape ``(*, d_model)``. + + :param x: The input of shape ``(*, d_model)``. + + :returns: The output of the MoE layer, the optional load-balancing loss, and the optional + router Z-loss. + """ + expert_logits, expert_weights, exper_indices = self.router(x) + out, batch_size_per_expert = self.experts(x, expert_weights, exper_indices) + if self.shared_experts is not None: + out = self.shared_experts(x, out, self.router.top_k) + + lb_loss: Optional[torch.Tensor] = None + z_loss: Optional[torch.Tensor] = None + if self.training and (self.lb_loss_weight is not None or self.z_loss_weight is not None): + expert_logits = expert_logits.float() + + # Compute load-balancing loss. + if self.lb_loss_weight is not None: + expert_scores = expert_logits.softmax(dim=-1) + total_bz = expert_scores.shape[0] + scale = (self.router.num_experts * self.lb_loss_weight) / ( + self.num_layers * total_bz * self.router.top_k + ) + lb_loss = scale * torch.dot(batch_size_per_expert, expert_scores) + + # Compute router Z-loss. + if self.z_loss_weight is not None: + z_loss = torch.logsumexp(expert_logits, dim=-1).square().mean() * self.z_loss_weight + + return out, lb_loss, z_loss + + def apply_ep(self, ep_mesh: DeviceMesh): + """ + Apply expert parallelism. + """ + self.experts.apply_ep(ep_mesh) + + +class DroplessMoE(MoEBase): + """ + A dropless MoE implementation. + """ + + @classmethod + def _init_parallel_mlp(cls, mlp: MoEMLP) -> ParallelMLP: + return ParallelDroplessMLP(mlp=mlp) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index d8eb4b2ca..ef6c7204c 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -11,8 +11,10 @@ from . import ops from .mlp import MoEMLP +__all__ = ["ParallelMLP", "ParallelDroplessMLP"] -class ParallelDroplessMLP(nn.Module): + +class ParallelMLP(nn.Module): """ Wraps an MoE MLP layer to coordinate the routing and expert parallelism. """ @@ -59,17 +61,29 @@ def apply_ep(self, ep_mesh: DeviceMesh): def forward( self, x: torch.Tensor, - scores: torch.Tensor, expert_weights: torch.Tensor, expert_indices: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ :param x: The input of shape ``(*, d_model)``. - :param scores: The expert scores of shape ``(N, num_experts)``, where ``N`` - typically equals ``batch_size x seq_len``. :param expert_weights: Expert weights of shape ``(N, top_k)``. :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. 
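To make the auxiliary losses concrete, a small self-contained sketch of the router Z-loss and a standard load-balancing term (routing probabilities averaged over tokens) using the same scale factor as ``MoEBase.forward`` above; the toy sizes and weights are arbitrary:

    import torch

    num_layers, num_experts, top_k = 2, 4, 1
    lb_weight, z_weight = 1.0, 1e-3
    logits = torch.randn(8, num_experts)              # router logits for 8 token copies

    # Router Z-loss: mean squared log-partition function of the router logits.
    z_loss = z_weight * torch.logsumexp(logits, dim=-1).square().mean()

    # Load-balancing loss: top-1 expert counts dotted with mean routing probabilities.
    counts = torch.histc(
        logits.argmax(dim=-1).float(), bins=num_experts, min=0, max=num_experts - 1
    )
    mean_probs = logits.softmax(dim=-1).mean(dim=0)
    scale = (num_experts * lb_weight) / (num_layers * logits.shape[0] * top_k)
    lb_loss = scale * torch.dot(counts, mean_probs)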
""" + del x, expert_weights, expert_indices + raise NotImplementedError + + +class ParallelDroplessMLP(ParallelMLP): + """ + A dropless implementation of a :class:`ParallelMLP`. + """ + + def forward( + self, + x: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: in_shape = x.size() # Compute the experts. @@ -78,12 +92,7 @@ def forward( else: x, batch_size_per_expert = self.forward_once(x, expert_weights, expert_indices) - del scores, batch_size_per_expert - # TODO: save load balancing loss - # if self.training and self.args.moe_loss_weight > 0: - # save_load_balancing_loss((tokens_per_expert, scores)) - - return x.view(in_shape) + return x.view(in_shape), batch_size_per_expert def forward_once( self, diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 135ea6e50..9e1a12088 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -146,11 +146,11 @@ def get_top_k(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return torch.topk(scores, self.top_k, dim=-1) @abstractmethod - def get_expert_scores(self, x: torch.Tensor) -> torch.Tensor: + def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: """ - Given the input ``x`` of shape ``(*, d_model)``, compute the expert scores. + Given the input ``x`` of shape ``(*, d_model)``, compute the un-normalized expert scores. - :returns: The expert scores, shape ``(*, num_experts)``. + :returns: The expert logits, shape ``(*, num_experts)``. """ raise NotImplementedError @@ -159,17 +159,17 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te Given the input ``x`` of shape ``(*, d_model)``, compute the experts assignment. - :returns: The scores of shape ``(N, num_experts)``, the expert weights + :returns: The logits of shape ``(N, num_experts)``, the expert weights of shape ``(N, top_k)``, and the expert indices of shape ``(N, top_k)``. 
""" # shape: (batch_size, seq_len, d_model) x = self.jitter(x) # shape: (batch_size * seq_len, num_experts) - scores = self.get_expert_scores(x.view(-1, self.d_model)) + logits = self.get_expert_logits(x.view(-1, self.d_model)) # shape: (batch_size * seq_len, top_k) - expert_weights, expert_indices = self.get_top_k(scores) + expert_weights, expert_indices = self.get_top_k(logits) if self.normalize_expert_weights is not None: expert_weights.div_( @@ -184,7 +184,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te if self.uniform_expert_assignment: expert_indices = _uniform_expert_assignment(expert_indices, self.num_experts) - return scores, expert_weights, expert_indices + return logits, expert_weights, expert_indices class MoELinearRouter(MoERouter): @@ -205,7 +205,5 @@ def __init__( self.d_model, self.num_experts, bias=bias, dtype=dtype, device=init_device ) - def get_expert_scores(self, x: torch.Tensor) -> torch.Tensor: - logits = self.w_score(x.view(-1, self.d_model)) - # TODO: save router logits for Z-loss - return logits.softmax(dim=-1) + def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: + return self.w_score(x.view(-1, self.d_model)) diff --git a/src/olmo_core/nn/moe/shared_mlp.py b/src/olmo_core/nn/moe/shared_mlp.py new file mode 100644 index 000000000..483d348e3 --- /dev/null +++ b/src/olmo_core/nn/moe/shared_mlp.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from ...config import Config, DType, StrEnum +from ...exceptions import OLMoConfigurationError +from ..feed_forward import FeedForward + +__all__ = ["SharedMLP", "SharedMLPConfig", "SharedMLPType"] + + +class SharedMLPType(StrEnum): + """ + An enumeration of the different shared MLP implementations. + """ + + default = "default" + """ + ➡️ :class:`SharedMLP` + """ + + +@dataclass +class SharedMLPConfig(Config): + """ + A config for building :class:`SharedMLP` modules. + """ + + name: SharedMLPType = SharedMLPType.default + """ + The name of the implementation. + """ + hidden_size: int = 256 + weighted_sum: bool = True + bias: Optional[bool] = None + dtype: DType = DType.float32 + + def num_params(self, d_model: int) -> int: + """ + The number of params that the module will have once built. + + :param d_model: The model dimensionality. + """ + params = 0 + + params += 3 * d_model * self.hidden_size + if self.bias: + params += 2 * self.hidden_size + d_model + + return params + + def build(self, d_model: int, *, init_device: str = "cpu") -> "SharedMLP": + """ + Build the corresponding shared MLP module. + + :param d_model: The model dimensionality. + :param init_device: The device initialize the parameters on, e.g. "cpu", "meta". 
+ """ + kwargs = self.as_dict(exclude_none=True) + kwargs.pop("name") + kwargs.update(d_model=d_model, init_device=init_device, dtype=kwargs.pop("dtype").as_pt()) + + try: + if self.name == SharedMLPType.default: + return SharedMLP(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e + + +class SharedMLP(nn.Module): + def __init__( + self, + *, + d_model: int, + hidden_size: int, + bias: bool = True, + weighted_sum: bool = True, + dtype: torch.dtype = torch.float32, + init_device: str = "cpu", + ): + super().__init__() + self.mlp = FeedForward( + d_model=d_model, + hidden_size=hidden_size, + bias=bias, + dtype=dtype, + init_device=init_device, + ) + self.weighted_sum = weighted_sum + + def forward(self, x: torch.Tensor, experts_out: torch.Tensor, top_k: int) -> torch.Tensor: + shared_out = self.mlp(x) + if self.weighted_sum: + # Weighted by number of experts used + n_active_experts = top_k + 1 + shared_out.div_(n_active_experts) + shared_out.add_(experts_out, alpha=top_k / n_active_experts) + else: + shared_out.add_(experts_out) + return shared_out From d1f49847766df2b057738066a8c11e17d10b4dae Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 31 Jan 2025 16:18:11 -0800 Subject: [PATCH 019/230] integrate new MoE code --- README.md | 2 +- docs/source/overview/installation.rst | 2 +- src/olmo_core/nn/moe/__init__.py | 26 +- src/olmo_core/nn/moe/config.py | 222 ------------------ src/olmo_core/nn/moe/layers.py | 97 -------- src/olmo_core/nn/moe/mlp.py | 14 +- src/olmo_core/nn/moe/moe.py | 26 +- src/olmo_core/nn/moe/router.py | 11 +- src/olmo_core/nn/moe/shared_mlp.py | 16 +- src/olmo_core/nn/transformer/__init__.py | 2 + src/olmo_core/nn/transformer/block.py | 39 +-- src/olmo_core/nn/transformer/config.py | 19 +- src/olmo_core/nn/transformer/model.py | 91 ++++++- .../train/train_module/transformer.py | 89 +++---- .../train_module/transformer_pipeline.py | 14 +- src/scripts/train/OLMoE-1B-7B.py | 23 +- src/test/nn/moe/mlp_test.py | 4 +- src/test/nn/{ => moe}/moe_test.py | 26 +- src/test/utils.py | 16 +- 19 files changed, 286 insertions(+), 453 deletions(-) delete mode 100644 src/olmo_core/nn/moe/config.py delete mode 100644 src/olmo_core/nn/moe/layers.py rename src/test/nn/{ => moe}/moe_test.py (54%) diff --git a/README.md b/README.md index 12c11688d..4f0749878 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ pip install ai2-olmo-core There are a number of optional dependencies that must be installed to use certain functionality as well, including: - [flash-attn](https://github.com/Dao-AILab/flash-attention) for flash attention and certain other fused operations. - [torchao](https://github.com/pytorch/ao) for float8 training. -- [megablocks](https://github.com/databricks/megablocks) for mixture-of-experts (MoE) models. +- [grouped_gemm](https://github.com/tgale96/grouped_gemm) for mixture-of-experts (MoE) models. The published [Docker images](https://github.com/orgs/allenai/packages?repo_name=OLMo-core) contain all core and optional dependencies, and are regularly tested on our in-house H100 clusters. 
But there are several things to keep in mind if you intend to use these images: diff --git a/docs/source/overview/installation.rst b/docs/source/overview/installation.rst index bcbf3ecbf..3b47e9e5f 100644 --- a/docs/source/overview/installation.rst +++ b/docs/source/overview/installation.rst @@ -12,4 +12,4 @@ There are a number of optional dependencies that must be installed to use certai - `flash-attn `_ for flash attention and certain other fused operations. - `torchao `_ for float8 training (see :mod:`olmo_core.float8`). -- `megablocks `_ for mixture-of-experts (MoE) models (see :mod:`olmo_core.nn.moe`). +- `grouped_gemm `_ for mixture-of-experts (MoE) models (see :mod:`olmo_core.nn.moe`). diff --git a/src/olmo_core/nn/moe/__init__.py b/src/olmo_core/nn/moe/__init__.py index 0b44056fc..ce3aad46f 100644 --- a/src/olmo_core/nn/moe/__init__.py +++ b/src/olmo_core/nn/moe/__init__.py @@ -1,9 +1,25 @@ """ -MoE layers. Requires `megablocks `_. +MoE layers. """ -from .config import MoEActivationFn, MoEConfig, MoEMLPImplementation, MoEType -from .handler import MoEHandler -from .layers import MoE +from .mlp import MoEMLP, MoEMLPConfig, MoEMLPType +from .moe import DroplessMoE, MoEBase, MoEConfig, MoEType +from .router import MoELinearRouter, MoERouter, MoERouterConfig, MoERouterType +from .shared_mlp import SharedMLP, SharedMLPConfig, SharedMLPType -__all__ = ["MoE", "MoEConfig", "MoEType", "MoEActivationFn", "MoEMLPImplementation", "MoEHandler"] +__all__ = [ + "MoEBase", + "DroplessMoE", + "MoEConfig", + "MoEType", + "MoEMLP", + "MoEMLPConfig", + "MoEMLPType", + "SharedMLP", + "SharedMLPConfig", + "SharedMLPType", + "MoERouter", + "MoELinearRouter", + "MoERouterConfig", + "MoERouterType", +] diff --git a/src/olmo_core/nn/moe/config.py b/src/olmo_core/nn/moe/config.py deleted file mode 100644 index cbde42fb0..000000000 --- a/src/olmo_core/nn/moe/config.py +++ /dev/null @@ -1,222 +0,0 @@ -from dataclasses import dataclass -from functools import partial -from typing import Callable - -import torch -import torch.nn.functional as F - -from olmo_core.config import Config, DType, StrEnum -from olmo_core.doc_utils import beta_feature - -from .layers import MoE as MoEWrapper - - -class MoEType(StrEnum): - """ - An enumeration of MoE layer types. - """ - - default = "default" - """ - The default version. - """ - - dropless = "dropless" - """ - The `dropless - `_ version. - """ - - -class MoEActivationFn(StrEnum): - """ - An enumeration of the different MoE activation functions available. - """ - - swiglu = "swiglu" - """ - SwiGLU. - """ - gelu = "gelu" - """ - GeLU. - """ - gelu_tanh = "gelu_tanh" - """ - GeLU with tanh approximation. - """ - relu = "relu" - """ - ReLU. - """ - - def build(self) -> Callable[[torch.Tensor], torch.Tensor]: - if self == MoEActivationFn.swiglu: - return partial(F.silu, inplace=False) - elif self == MoEActivationFn.gelu: - return partial(F.gelu, approximate="none") - elif self == MoEActivationFn.gelu_tanh: - return partial(F.gelu, approximate="tanh") - elif self == MoEActivationFn.relu: - return partial(F.relu, inplace=False) - else: - raise NotImplementedError(self) - - -class MoEMLPImplementation(StrEnum): - """ - An enumeration of the different MoE implementations. - """ - - sparse = "sparse" - """ - Sparse implementation. - """ - grouped = "grouped" - """ - Requires the `grouped GEMM - `_ package. - """ - - -@beta_feature -@dataclass -class MoEConfig(Config): - """ - Configuration class for building MoE layers. - - .. important:: - Requires `megablocks `_. 
- """ - - name: MoEType = MoEType.default - """ - The MoE implementation. - """ - hidden_size: int = 4096 - """ - The MLP hidden size. - """ - activation_fn: MoEActivationFn = MoEActivationFn.swiglu - """ - The activation function to use. - """ - mlp_implementation: MoEMLPImplementation = MoEMLPImplementation.sparse - """ - The MLP implementation. - """ - memory_optimized_mlp: bool = False - """ - Use the memory-optimized version of the MLP. - """ - num_experts: int = 8 - """ - The number of experts to use in the MoE block. - """ - top_k: int = 2 - """ - The number of experts to select for each token. - """ - capacity_factor: int = 1 - """ - The capacity factor to use in the MoE block. Only applies if not using :data:`MoEType.dropless`. - """ - bias: bool = True - """ - Include bias terms. - """ - loss_weight: float = 0.1 - """ - The weight to use for the MoE load balancing loss. - """ - zloss_weight: float = 0.0 - """ - Weight for MoE router z-loss where None means no router z-loss. 0.001 is a common value. - """ - zloss_in_fp32: bool = False - """ - Whether to compute the z-loss in FP32. - """ - shared_expert: bool = False - """ - Whether to have an always-used expert like in `DeepSeekMoE - `_. - """ - lbl_in_fp32: bool = False - """ - Whether to perform load balancing in FP32. - """ - num_layers: int = 1 - """ - The total number of MoE layers. - """ - dtype: DType = DType.float32 - """ - The data type for the parameters. - """ - - def num_params(self, d_model: int) -> int: - num_params = 0 - - # Router. - num_params += self.num_experts * d_model - - # Experts. - num_params += self.num_experts * (2 * d_model * self.hidden_size) - if self.name == MoEType.dropless and "glu" in self.activation_fn.lower(): - num_params += self.num_experts * d_model * self.hidden_size - - # Bias. - if self.bias: - num_params += d_model - - return num_params - - def as_megablocks_args(self, *, d_model: int, init_device: str = "cpu"): - from megablocks.layers.arguments import Arguments # type: ignore - - return Arguments( - hidden_size=d_model, - activation_fn=self.activation_fn.build(), - mlp_type="glu" if "glu" in self.activation_fn.lower() else "mlp", - mlp_impl=self.mlp_implementation, - memory_optimized_mlp=self.memory_optimized_mlp, - ffn_hidden_size=self.hidden_size, - moe_num_experts=self.num_experts, - moe_top_k=self.top_k, - moe_capacity_factor=self.capacity_factor, - moe_loss_weight=self.loss_weight, - moe_zloss_weight=self.zloss_weight, - moe_zloss_in_fp32=self.zloss_in_fp32, - moe_lbl_in_fp32=self.lbl_in_fp32, - shared_expert=self.shared_expert, - bias=self.bias, - return_bias=False, - num_layers=self.num_layers, - device=torch.device(init_device), - fp16=False, - bf16=self.dtype == DType.bfloat16, - ) - - def build(self, *, d_model: int, init_device: str = "cpu") -> MoEWrapper: - """ - Build the MoE layer. - - :param d_model: The model dimensionality. - :param init_device: The device to initialize weights on. - """ - try: - from megablocks.layers.dmoe import dMoE - from megablocks.layers.moe import MoE - except ImportError as e: - raise ImportError( - "megablocks is not installed. 
Please install it to use MoE layers" - ) from e - - args = self.as_megablocks_args(d_model=d_model, init_device=init_device) - if self.name == MoEType.default: - return MoEWrapper(args, MoE(args)) - elif self.name == MoEType.dropless: - return MoEWrapper(args, dMoE(args)) - else: - raise NotImplementedError(self.name) diff --git a/src/olmo_core/nn/moe/layers.py b/src/olmo_core/nn/moe/layers.py deleted file mode 100644 index 9877627a1..000000000 --- a/src/olmo_core/nn/moe/layers.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn - -from olmo_core.doc_utils import beta_feature - - -@beta_feature -class MoE(nn.Module): - """ - A thin wrapper around `megablocks `_ MoE layers. - - .. tip:: - Use :class:`MoEConfig` to build instances of this module. - """ - - def __init__(self, args, inner): - super().__init__() - self.args = args - self.inner = inner - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Run the MoE on the input. - - :param x: A tensor of shape ``(batch_size, sequence_length, d_model)``. - """ - return self.inner(x) - - def get_load_balancing_loss(self) -> Optional[torch.Tensor]: - """ - Get the batched load-balancing loss from the internal buffers. - - .. important:: - This method will clear the internal buffers so can only be called once per forward pass. - """ - from megablocks.layers.moe import ( # type: ignore - batched_load_balancing_loss, - clear_load_balancing_loss, - ) - - if isinstance(lb_loss := batched_load_balancing_loss(self.args), torch.Tensor): - clear_load_balancing_loss() - return lb_loss - else: - return None - - def get_router_z_loss(self) -> Optional[torch.Tensor]: - """ - Get the batched router Z-loss from the internal buffers. - - .. important:: - This method will clear the internal buffers so can only be called once per forward pass. - """ - from megablocks.layers.router import ( # type: ignore - batched_router_zloss, - clear_router_zloss, - ) - - if self.args.moe_zloss_weight != 0 and isinstance( - (z_loss_per_layer := batched_router_zloss(self.args)), torch.Tensor - ): - z_loss = z_loss_per_layer.sum() / self.args.num_layers - clear_router_zloss() - return z_loss - else: - return None - - def get_loss(self) -> Optional[torch.Tensor]: - """ - Get the batched combined load-balancing loss and router Z-loss from the internal buffers. - - .. important:: - This method will clear the internal buffers so can only be called once per forward pass. - """ - loss: Optional[torch.Tensor] = None - if (lb_loss := self.get_load_balancing_loss()) is not None: - loss = lb_loss - - if (rz_loss := self.get_router_z_loss()) is not None: - if loss is not None: - loss += rz_loss - else: - loss = rz_loss - - return loss - - def clear_losses(self): - """ - Clear internal loss buffers. - """ - from megablocks.layers.moe import clear_load_balancing_loss # type: ignore - from megablocks.layers.router import clear_router_zloss # type: ignore - - clear_load_balancing_loss() - clear_router_zloss() diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 80dc3a9a7..42bfe40a1 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -49,30 +49,34 @@ class MoEMLPConfig(Config): The name of the implementation. """ - hidden_size: int = 1024 - num_experts: int = 1 dtype: DType = DType.float32 - def num_params(self, d_model: int) -> int: + def num_params(self, d_model: int, num_experts: int, hidden_size: int) -> int: """ The number of params that the module will have once built. 
:param d_model: The model dimensionality. + :param num_experts: Then number of experts. + :param hidden_size: The hidden size of each expert. """ num_params = 0 if self.name == MoEMLPType.default: - num_params += 3 * d_model * self.hidden_size * self.num_experts + num_params += 3 * d_model * hidden_size * num_experts else: raise NotImplementedError return num_params - def build(self, d_model: int, *, init_device: str = "cpu") -> "MoEMLP": + def build( + self, d_model: int, num_experts: int, hidden_size: int, *, init_device: str = "cpu" + ) -> "MoEMLP": kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") kwargs.update( dtype=kwargs.pop("dtype").as_pt(), d_model=d_model, + num_experts=num_experts, + hidden_size=hidden_size, init_device=init_device, ) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 86b17b43e..9bcc131b1 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -33,12 +33,24 @@ class MoEConfig(Config): """ The name of the implementation. """ + num_experts: int = 1 + hidden_size: int = 256 router: MoERouterConfig = field(default_factory=MoERouterConfig) mlp: MoEMLPConfig = field(default_factory=MoEMLPConfig) shared_mlp: Optional[SharedMLPConfig] = None - lb_loss_weight: Optional[float] = None + lb_loss_weight: Optional[float] = 1.0 z_loss_weight: Optional[float] = None + def num_params(self, d_model: int) -> int: + num_params = 0 + + num_params += self.router.num_params(d_model, self.num_experts) + num_params += self.mlp.num_params(d_model, self.num_experts, self.hidden_size) + if self.shared_mlp is not None: + num_params += self.shared_mlp.num_params(d_model, self.hidden_size) + + return num_params + def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> "MoEBase": kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") @@ -69,6 +81,8 @@ def __init__( self, *, d_model: int, + num_experts: int, + hidden_size: int, router: MoERouterConfig, mlp: MoEMLPConfig, num_layers: int, @@ -78,10 +92,14 @@ def __init__( z_loss_weight: Optional[float] = None, ): super().__init__() - self.router = router.build(d_model, init_device=init_device) - self.experts = self._init_parallel_mlp(mlp.build(d_model, init_device=init_device)) + self.router = router.build(d_model, num_experts, init_device=init_device) + self.experts = self._init_parallel_mlp( + mlp.build(d_model, num_experts, hidden_size, init_device=init_device) + ) self.shared_experts = ( - None if shared_mlp is None else shared_mlp.build(d_model, init_device=init_device) + None + if shared_mlp is None + else shared_mlp.build(d_model, hidden_size, init_device=init_device) ) self.num_layers = num_layers self.lb_loss_weight = lb_loss_weight diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 9e1a12088..6579a9464 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -50,7 +50,6 @@ class MoERouterConfig(Config): """ The name of the implementation. """ - num_experts: int = 1 top_k: int = 1 jitter_eps: Optional[float] = None normalize_expert_weights: Optional[float] = None @@ -58,7 +57,7 @@ class MoERouterConfig(Config): bias: bool = True dtype: DType = DType.float32 - def num_params(self, d_model: int) -> int: + def num_params(self, d_model: int, num_experts: int) -> int: """ The number of params that the module will have once built. 
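As a worked example of these parameter counts (arbitrary sizes, bias-free router):

    d_model, num_experts, hidden_size = 1024, 64, 512
    router_params = d_model * num_experts                  # 65_536
    mlp_params = 3 * d_model * hidden_size * num_experts   # 100_663_296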
@@ -66,19 +65,20 @@ def num_params(self, d_model: int) -> int: """ num_params = 0 if self.name == MoERouterType.default: - num_params += d_model * self.num_experts + num_params += d_model * num_experts if self.bias: - num_params += self.num_experts + num_params += num_experts else: raise NotImplementedError return num_params - def build(self, d_model: int, *, init_device: str = "cpu") -> "MoERouter": + def build(self, d_model: int, num_experts, *, init_device: str = "cpu") -> "MoERouter": """ Build the corresponding MoE router module. :param d_model: The model dimensionality. + :param num_experts: The number of experts. :param init_device: The device initialize the parameters on, e.g. "cpu", "meta". """ kwargs = self.as_dict(exclude_none=True, recurse=False) @@ -86,6 +86,7 @@ def build(self, d_model: int, *, init_device: str = "cpu") -> "MoERouter": kwargs.update( dtype=kwargs.pop("dtype").as_pt(), d_model=d_model, + num_experts=num_experts, init_device=init_device, ) diff --git a/src/olmo_core/nn/moe/shared_mlp.py b/src/olmo_core/nn/moe/shared_mlp.py index 483d348e3..e3f9523a3 100644 --- a/src/olmo_core/nn/moe/shared_mlp.py +++ b/src/olmo_core/nn/moe/shared_mlp.py @@ -32,12 +32,11 @@ class SharedMLPConfig(Config): """ The name of the implementation. """ - hidden_size: int = 256 weighted_sum: bool = True bias: Optional[bool] = None dtype: DType = DType.float32 - def num_params(self, d_model: int) -> int: + def num_params(self, d_model: int, hidden_size: int) -> int: """ The number of params that the module will have once built. @@ -45,13 +44,13 @@ def num_params(self, d_model: int) -> int: """ params = 0 - params += 3 * d_model * self.hidden_size + params += 3 * d_model * hidden_size if self.bias: - params += 2 * self.hidden_size + d_model + params += 2 * hidden_size + d_model return params - def build(self, d_model: int, *, init_device: str = "cpu") -> "SharedMLP": + def build(self, d_model: int, hidden_size: int, *, init_device: str = "cpu") -> "SharedMLP": """ Build the corresponding shared MLP module. 
@@ -60,7 +59,12 @@ def build(self, d_model: int, *, init_device: str = "cpu") -> "SharedMLP": """ kwargs = self.as_dict(exclude_none=True) kwargs.pop("name") - kwargs.update(d_model=d_model, init_device=init_device, dtype=kwargs.pop("dtype").as_pt()) + kwargs.update( + d_model=d_model, + hidden_size=hidden_size, + init_device=init_device, + dtype=kwargs.pop("dtype").as_pt(), + ) try: if self.name == SharedMLPType.default: diff --git a/src/olmo_core/nn/transformer/__init__.py b/src/olmo_core/nn/transformer/__init__.py index 1bd911710..cd54f699e 100644 --- a/src/olmo_core/nn/transformer/__init__.py +++ b/src/olmo_core/nn/transformer/__init__.py @@ -11,6 +11,7 @@ from .config import TransformerConfig, TransformerType from .init import InitMethod from .model import ( + MoETransformer, NormalizedTransformer, Transformer, TransformerActivationCheckpointingMode, @@ -22,6 +23,7 @@ "TransformerConfig", "Transformer", "NormalizedTransformer", + "MoETransformer", "TransformerBlockType", "TransformerBlockConfig", "TransformerBlockBase", diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 5b633b8b3..7a4634fac 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -48,7 +48,7 @@ class TransformerBlockType(StrEnum): ➡️ :class:`MoETransformerBlock` """ - moe_reordered_norm = "moe" + moe_reordered_norm = "moe_reordered_norm" """ ➡️ :class:`MoEReorderedNormTransformerBlock` """ @@ -90,6 +90,7 @@ def build( *, d_model: int, block_idx: int, + num_blocks: int, init_device: str = "cpu", cache: Optional[BufferCache] = None, ) -> "TransformerBlockBase": @@ -110,9 +111,9 @@ def build( elif self.name == TransformerBlockType.normalized: return NormalizedTransformerBlock(**kwargs) elif self.name == TransformerBlockType.moe: - return MoETransformerBlock(**kwargs) + return MoETransformerBlock(num_blocks=num_blocks, **kwargs) elif self.name == TransformerBlockType.moe_reordered_norm: - return MoEReorderedNormTransformerBlock(**kwargs) + return MoEReorderedNormTransformerBlock(num_blocks=num_blocks, **kwargs) else: raise NotImplementedError(self.name) except TypeError as e: @@ -390,6 +391,7 @@ def __init__( attention: AttentionConfig, feed_forward_moe: MoEConfig, layer_norm: LayerNormConfig, + num_blocks: int, dropout: float = 0.0, init_device: str = "cpu", cache: Optional[BufferCache] = None, @@ -399,7 +401,9 @@ def __init__( self.block_idx = block_idx self.attention = attention.build(d_model, init_device=init_device, cache=cache) self.attention_norm = layer_norm.build(d_model, init_device=init_device) - self.feed_forward_moe = feed_forward_moe.build(d_model=d_model, init_device=init_device) + self.feed_forward_moe = feed_forward_moe.build( + d_model=d_model, num_layers=num_blocks, init_device=init_device + ) self.feed_forward_norm = layer_norm.build(d_model, init_device=init_device) self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() @@ -408,7 +412,7 @@ def forward( x: torch.Tensor, max_doc_len: Optional[int] = None, cu_doc_lens: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Run the block on the input ``x``. 
@@ -417,12 +421,18 @@ def forward( h = x + self.dropout( self.attention(self.attention_norm(x), max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens) ) - return h + self.dropout(self.feed_forward_moe(self.feed_forward_norm(h))) + moe_out, lb_loss, z_loss = self.feed_forward_moe(self.feed_forward_norm(h)) + return h + self.dropout(moe_out), lb_loss, z_loss + + def apply_ep(self, ep_mesh: DeviceMesh): + self.feed_forward_moe.apply_ep(ep_mesh) def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): del tp_mesh, float8_enabled - raise NotImplementedError("TP is not implemented yet for the MoE transformer block variant") + raise NotImplementedError( + f"TP is not implemented yet for the '{self.__class__.__name__}' variant" + ) def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: raise NotImplementedError @@ -440,18 +450,9 @@ def forward( x: torch.Tensor, max_doc_len: Optional[int] = None, cu_doc_lens: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: h = x + self.dropout( self.attention_norm(self.attention(x, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens)) ) - return h + self.dropout(self.feed_forward_norm(self.feed_forward_moe(h))) - - def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): - del tp_mesh, float8_enabled - - raise NotImplementedError( - "TP is not implemented yet for the MoE reordered norm transformer block variant" - ) - - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - raise NotImplementedError + moe_out, lb_loss, z_loss = self.feed_forward_moe(h) + return h + self.dropout(self.feed_forward_norm(moe_out)), lb_loss, z_loss diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index feaac4b1b..090fcecfd 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -11,7 +11,7 @@ from ..rope import RoPEConfig, RoPEScalingConfig, RoPEType from .block import TransformerBlockConfig, TransformerBlockType from .init import InitMethod -from .model import NormalizedTransformer, Transformer +from .model import MoETransformer, NormalizedTransformer, Transformer log = logging.getLogger(__name__) @@ -31,6 +31,11 @@ class TransformerType(StrEnum): ➡️ :class:`NormalizedTransformer` (nGPT) """ + moe = "moe" + """ + ➡️ :class:`MoETransformer` + """ + @dataclass class TransformerConfig(Config): @@ -92,6 +97,18 @@ def build( init_device=init_device, init_seed=self.init_seed, ) + elif self.name == TransformerType.moe: + model = MoETransformer( + d_model=self.d_model, + vocab_size=self.vocab_size, + n_layers=self.n_layers, + block=self.block, + lm_head=self.lm_head, + dtype=self.dtype.as_pt(), + init_method=self.init_method, + init_device=init_device, + init_seed=self.init_seed, + ) else: raise NotImplementedError(self.name) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index e4e6d9a24..b14797bbd 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -1,6 +1,6 @@ import logging from functools import cached_property -from typing import List, Optional, Sequence, cast +from typing import List, Optional, Sequence, Tuple, cast import torch import torch.nn as nn @@ -9,6 +9,7 @@ from olmo_core.config import StrEnum from olmo_core.data.utils import get_cumulative_document_lengths from olmo_core.doc_utils import beta_feature +from olmo_core.exceptions import OLMoConfigurationError from 
olmo_core.utils import get_default_device from ..buffer_cache import BufferCache @@ -17,6 +18,7 @@ from ..utils import selective_checkpointing_context_fn from .block import ( MoETransformerBlock, + NormalizedTransformerBlock, TransformerBlock, TransformerBlockBase, TransformerBlockConfig, @@ -26,6 +28,7 @@ __all__ = [ "Transformer", "NormalizedTransformer", + "MoETransformer", "TransformerDataParallelWrappingStrategy", "TransformerActivationCheckpointingMode", ] @@ -110,12 +113,15 @@ def __init__( self.embeddings = nn.Embedding(vocab_size, d_model, dtype=dtype, device=init_device) self.blocks = nn.ModuleDict() for block_idx in range(n_layers): - self.blocks[str(block_idx)] = block.build( + block_ = block.build( d_model=d_model, block_idx=block_idx, + num_blocks=n_layers, init_device=init_device, cache=cache, ) + self._validate_block(block_) + self.blocks[str(block_idx)] = block_ self.lm_head = lm_head.build( d_model=d_model, vocab_size=vocab_size, init_device=init_device ) @@ -130,6 +136,13 @@ def __init__( self.num_params self.num_non_embedding_params + def _validate_block(self, block: TransformerBlockBase): + del block + + @property + def is_moe(self) -> bool: + return False + @property def device(self) -> torch.device: for p in self.parameters(): @@ -532,6 +545,12 @@ def __init__( init_seed=init_seed, ) + def _validate_block(self, block: TransformerBlockBase): + if not isinstance(block, NormalizedTransformerBlock): + raise OLMoConfigurationError( + f"'{self.__class__.__name__}' requires a '{NormalizedTransformerBlock.__name__}' block" + ) + @torch.no_grad() def init_weights( self, @@ -578,3 +597,71 @@ def apply_tp( def apply_compile(self): super().apply_compile() self.normalize_matrices = torch.compile(self.normalize_matrices) + + +@beta_feature +class MoETransformer(Transformer): + """ + An MoE transformer implementation, to be used with one of the + :class:`MoETransformerBlock` block types. + """ + + @property + def is_moe(self) -> bool: + return True + + def _validate_block(self, block: TransformerBlockBase): + if not isinstance(block, MoETransformerBlock): + raise OLMoConfigurationError( + f"'{self.__class__.__name__}' requires a '{MoETransformerBlock.__name__}' block" + ) + + def forward( + self, + input_ids: torch.Tensor, + doc_lens: Optional[torch.Tensor] = None, + max_doc_lens: Optional[Sequence[int]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Run the transformer on the token input IDs. + + :param input_ids: The token input IDs, shape ``(batch_size, seq_len)``. + :param doc_lens: Document lengths to use in attention for intra-document masking. + Shape ``(batch_size, max_docs)``. + Required together with ``max_doc_lens`` when using intra-document masking. + :param max_doc_lens: Maximum document length for each instance in the batch. + Required together with ``doc_lens`` when using intra-document masking. + + :returns: The output logits, the optional load-balancing loss, and the optional router Z-loss. 
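A minimal sketch of consuming these three outputs in a training step; ``model``, ``input_ids``, and ``labels`` are placeholders, and plain ``F.cross_entropy`` stands in for the trainer's own loss functions:

    import torch.nn.functional as F

    logits, lb_loss, z_loss = model(input_ids)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    if lb_loss is not None:
        loss = loss + lb_loss
    if z_loss is not None:
        loss = loss + z_loss
    loss.backward()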
+ """ + max_doc_len: Optional[int] = None + cu_doc_lens: Optional[torch.Tensor] = None + if doc_lens is not None and max_doc_lens is not None: + max_doc_len = max(max_doc_lens) + cu_doc_lens = get_cumulative_document_lengths(doc_lens) + + # passthrough for non-existent layers, allows easy pipeline parallel configuration + h = self.embeddings(input_ids) if self.embeddings is not None else input_ids + + lb_losses: List[torch.Tensor] = [] + z_losses: List[torch.Tensor] = [] + for block in self.blocks.values(): + h, lb_loss, z_loss = block(h, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens) + if lb_loss is not None: + lb_losses.append(lb_loss) + if z_loss is not None: + z_losses.append(z_loss) + + lb_loss = None + if lb_losses: + lb_loss = torch.stack(lb_losses).sum() / self.n_layers + + z_loss = None + if z_losses: + z_loss = torch.stack(z_losses).sum() / self.n_layers + + return self.lm_head(h) if self.lm_head is not None else h, lb_loss, z_loss + + def apply_ep(self, ep_mesh: DeviceMesh): + for block in self.blocks.values(): + cast(MoETransformerBlock, block).apply_ep(ep_mesh) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 1d0e6b66c..7916de311 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -33,7 +33,6 @@ cross_entropy_loss, fused_cross_entropy_loss, ) -from olmo_core.nn.moe import MoEHandler from olmo_core.nn.transformer import ( NormalizedTransformer, Transformer, @@ -364,10 +363,6 @@ def __init__( self.load_key_mapping = load_key_mapping self.label_ignore_index = label_ignore_index - self.moe_handler: Optional[MoEHandler] = None - if MoEHandler.has_moe(self.model): - self.moe_handler = MoEHandler(model=self.model) - @property def dp_process_group(self) -> Optional[dist.ProcessGroup]: return get_dp_process_group(self.world_mesh) @@ -505,13 +500,13 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Calculate how many tokens are going to be used in the loss. batch_num_tokens_for_loss = (batch["labels"] != self.label_ignore_index).sum() - # Update overall CE batch loss. + # Batch losses. ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) - - # Update overall Z batch loss. z_batch_loss: Optional[torch.Tensor] = None if self.z_loss_multiplier is not None: z_batch_loss = move_to_device(torch.tensor(0.0), self.device) + moe_batch_lb_loss: Optional[torch.Tensor] = None + moe_batch_z_loss: Optional[torch.Tensor] = None # Split into micro-batches. if self.rank_microbatch_size < (seq_len := batch["input_ids"].shape[1]): @@ -525,7 +520,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): for micro_batch_idx, micro_batch in enumerate(micro_batches): with self._train_microbatch_context(micro_batch_idx, num_micro_batches): # Run forward pass. - logits = self.model_forward(micro_batch) + logits, moe_lb_loss, moe_z_loss = self.model_forward(micro_batch) loss, ce_loss, z_loss = self.loss_fn( logits, micro_batch["labels"], batch_num_tokens_for_loss ) @@ -537,12 +532,20 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): z_batch_loss += z_loss # Maybe add MoE losses. 
- if self.moe_handler is not None: - moe_loss = self.moe_handler.get_combined_loss( - batch=batch, micro_batch=micro_batch - ) - if moe_loss is not None: - loss += moe_loss + if moe_lb_loss is not None: + loss += moe_lb_loss + moe_lb_loss = get_local_tensor(moe_lb_loss.detach()) + if moe_batch_lb_loss is None: + moe_batch_lb_loss = moe_lb_loss + else: + moe_batch_lb_loss += moe_lb_loss + if moe_z_loss is not None: + loss += moe_z_loss + moe_z_loss = get_local_tensor(moe_z_loss.detach()) + if moe_batch_z_loss is None: + moe_batch_z_loss = moe_z_loss + else: + moe_batch_z_loss += moe_z_loss # Run backward pass. loss.backward() @@ -550,33 +553,36 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): del batch # In case this helps with memory utilization. if dry_run: - self._clear_loss_buffers() return # Record loss metrics. - self.record_ce_loss(ce_batch_loss / get_world_size(self.dp_process_group), ReduceType.sum) - + self.record_ce_loss(ce_batch_loss, ReduceType.mean) if self.z_loss_multiplier is not None: assert z_batch_loss is not None self.record_metric( "Z loss", - z_batch_loss / get_world_size(self.dp_process_group), - ReduceType.sum, + z_batch_loss, + ReduceType.mean, + namespace="train", + ) + if moe_batch_lb_loss is not None: + self.record_metric( + "load balancing loss", + moe_batch_lb_loss, + ReduceType.mean, + namespace="train", + ) + if moe_batch_z_loss is not None: + self.record_metric( + "router Z loss", + moe_batch_z_loss, + ReduceType.mean, namespace="train", ) - - if self.moe_handler is not None: - if (moe_lb_loss := self.moe_handler.get_lb_loss()) is not None: - self.record_metric("load balancing loss", moe_lb_loss, namespace="train") - if (moe_z_loss := self.moe_handler.get_z_loss()) is not None: - self.record_metric("router Z loss", moe_z_loss, namespace="train") if isinstance(self.optim, SkipStepOptimizer): self.optim.latest_loss = ce_batch_loss - # Lastly, clear internal loss buffers. - self._clear_loss_buffers() - def eval_batch( self, batch: Dict[str, Any], labels: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -585,13 +591,11 @@ def eval_batch( self.model.eval() with torch.no_grad(): - logits = self.model_forward(batch) + logits, _, _ = self.model_forward(batch) loss: Optional[torch.Tensor] = None if labels is not None: loss = self.eval_loss_fn(logits, labels) - self._clear_loss_buffers() - return logits, loss def optim_step(self): @@ -665,9 +669,11 @@ def optim_step(self): def zero_grads(self): self.optim.zero_grad(set_to_none=True) - def model_forward(self, batch: Dict[str, Any]) -> torch.Tensor: + def model_forward( + self, batch: Dict[str, Any] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ - Run a forward pass on a micro-batch, returning the logits and potentially the loss. + Run a forward pass on a micro-batch, returning the logits and potentially MoE losses. """ with self._model_forward_context(): # NOTE: Input sizes might be dynamic, e.g. when training with variable sequence lengths @@ -681,7 +687,7 @@ def model_forward(self, batch: Dict[str, Any]) -> torch.Tensor: # Run model forward, get logits. 
# shape: (batch_size, seq_len, vocab_size) - logits = self.model( + output = self.model( input_ids=batch["input_ids"], # attention_mask=micro_batch.get("attention_mask"), # attention_bias=micro_batch.get("attention_bias"), @@ -689,7 +695,14 @@ def model_forward(self, batch: Dict[str, Any]) -> torch.Tensor: max_doc_lens=batch.get("max_doc_lens"), ) - return logits + moe_lb_loss: Optional[torch.Tensor] = None + moe_z_loss: Optional[torch.Tensor] = None + if self.model.is_moe: + logits, moe_lb_loss, moe_z_loss = output + else: + logits = output + + return logits, moe_lb_loss, moe_z_loss def num_flops_per_token(self, seq_len: int) -> int: return self.model.num_flops_per_token(seq_len) @@ -711,10 +724,6 @@ def _model_forward_context(self) -> Generator[None, None, None]: stack.enter_context(torch.autocast(self.device.type, dtype=self.autocast_precision)) yield - def _clear_loss_buffers(self): - if self.moe_handler is not None: - self.moe_handler.clear_loss_buffers() - def _get_state_dict(self, sd_options: dist_cp_sd.StateDictOptions) -> Dict[str, Any]: return { "model": dist_cp_sd.get_model_state_dict(self.model, options=sd_options), diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 85327a741..33bfe290e 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -37,7 +37,6 @@ cross_entropy_loss, fused_cross_entropy_loss, ) -from olmo_core.nn.moe import MoEHandler from olmo_core.nn.transformer import NormalizedTransformer, Transformer from olmo_core.optim import OptimConfig, SkipStepOptimizer from olmo_core.optim.scheduler import Scheduler @@ -447,10 +446,9 @@ def __init__( self.load_key_mapping = load_key_mapping self.label_ignore_index = label_ignore_index - self.moe_handler: Optional[MoEHandler] = None for model in self.model_parts: - if MoEHandler.has_moe(model): - # TODO (epwalsh): need to figure out how to handle the internal MoE losses correctly. + if model.is_moe: + # TODO (epwalsh): need to handle the internal MoE losses correctly. 
raise NotImplementedError( "Pipeline parallelism with MoE's is currently not supported" ) @@ -683,12 +681,6 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): namespace="train", ) - if self.moe_handler is not None: - if (moe_lb_loss := self.moe_handler.get_lb_loss()) is not None: - self.record_metric("load balancing loss", moe_lb_loss, namespace="train") - if (moe_z_loss := self.moe_handler.get_z_loss()) is not None: - self.record_metric("router Z loss", moe_z_loss, namespace="train") - for optim in self.optimizers: if isinstance(optim, SkipStepOptimizer): assert self._ce_batch_loss is not None @@ -850,8 +842,6 @@ def _clear_loss_buffers(self): self._batch_num_tokens_for_loss = None self._ce_batch_loss = None self._z_batch_loss = None - if self.moe_handler is not None: - self.moe_handler.clear_loss_buffers() def _get_state_dict(self, sd_options: dist_cp_sd.StateDictOptions) -> Dict[str, Any]: return { diff --git a/src/scripts/train/OLMoE-1B-7B.py b/src/scripts/train/OLMoE-1B-7B.py index bf0ce2197..8e686f57d 100644 --- a/src/scripts/train/OLMoE-1B-7B.py +++ b/src/scripts/train/OLMoE-1B-7B.py @@ -6,8 +6,12 @@ from olmo_core.config import DType from olmo_core.distributed.parallel import DataParallelType from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.moe import MoEActivationFn, MoEConfig, MoEMLPImplementation, MoEType -from olmo_core.nn.transformer import TransformerBlockType, TransformerConfig +from olmo_core.nn.moe import MoEConfig, MoEMLPConfig, MoERouterConfig, MoEType +from olmo_core.nn.transformer import ( + TransformerBlockType, + TransformerConfig, + TransformerType, +) from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride from olmo_core.train import TrainerConfig from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback @@ -25,19 +29,16 @@ def build_model_config(common: CommonComponents) -> TransformerConfig: n_heads=16, block_name=TransformerBlockType.moe_reordered_norm, ) + model_config.name = TransformerType.moe model_config.block.feed_forward = None model_config.block.feed_forward_moe = MoEConfig( name=MoEType.dropless, - hidden_size=int(0.5 * model_config.d_model), - activation_fn=MoEActivationFn.swiglu, - mlp_implementation=MoEMLPImplementation.grouped, num_experts=64, - top_k=8, - num_layers=model_config.n_layers, - zloss_weight=0.001, - loss_weight=0.01, - bias=False, - dtype=model_config.dtype, + hidden_size=int(0.5 * model_config.d_model), + router=MoERouterConfig(top_k=8, bias=False), + mlp=MoEMLPConfig(), + lb_loss_weight=0.01, + z_loss_weight=0.001, ) return model_config diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index 225eb4739..91e7eb603 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -6,10 +6,11 @@ from olmo_core.utils import get_default_device from ...distributed.utils import requires_multi_gpu, run_distributed_test -from ...utils import requires_gpu +from ...utils import requires_gpu, requires_grouped_gemm @requires_gpu +@requires_grouped_gemm def test_mlp(): mlp = MoEMLP( d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 @@ -41,5 +42,6 @@ def run_mlp_with_expert_parallelism(): @requires_multi_gpu +@requires_grouped_gemm def test_mlp_with_expert_parallelism(): run_distributed_test(run_mlp_with_expert_parallelism, backend="nccl", start_method="spawn") diff --git a/src/test/nn/moe_test.py b/src/test/nn/moe/moe_test.py similarity index 54% rename from 
src/test/nn/moe_test.py rename to src/test/nn/moe/moe_test.py index 4603d3e92..c8d424e0d 100644 --- a/src/test/nn/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -2,26 +2,24 @@ import torch from olmo_core.config import DType -from olmo_core.nn.moe import MoEConfig, MoEMLPImplementation, MoEType +from olmo_core.nn.moe import MoEConfig, MoEMLPConfig, MoERouterConfig, MoEType -from ..utils import requires_gpu, requires_megablocks +from ...utils import requires_gpu, requires_grouped_gemm @requires_gpu -@requires_megablocks -@pytest.mark.parametrize("moe_type", [MoEType.default, MoEType.dropless]) -@pytest.mark.parametrize("mlp_impl", [MoEMLPImplementation.sparse, MoEMLPImplementation.grouped]) +@requires_grouped_gemm +@pytest.mark.parametrize("moe_type", [MoEType.dropless]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) -def test_moe(moe_type, mlp_impl, dtype): +def test_moe(moe_type, dtype): d_model = 128 config = MoEConfig( name=moe_type, - mlp_implementation=mlp_impl, - hidden_size=512, - num_experts=4, - dtype=DType.from_pt(dtype), + router=MoERouterConfig(num_experts=4, top_k=1, dtype=DType.from_pt(dtype)), + mlp=MoEMLPConfig(num_experts=8, hidden_size=512, dtype=DType.from_pt(dtype)), + z_loss_weight=0.1, ) - moe = config.build(d_model=d_model, init_device="cuda") + moe = config.build(d_model=d_model, num_layers=1, init_device="cuda") # Check num params calculation. num_params = 0 @@ -35,9 +33,11 @@ def test_moe(moe_type, mlp_impl, dtype): # Run forward pass. x = torch.randn(2, 16, d_model, dtype=dtype, device="cuda", requires_grad=True) - output = moe(x) + output, lb_loss, z_loss = moe(x) assert output.shape == x.shape - loss = output.sum() + moe.get_loss() + assert lb_loss is not None + assert z_loss is not None + loss = lb_loss + z_loss # Run backward pass. 
loss.backward() diff --git a/src/test/utils.py b/src/test/utils.py index 14a622aa7..65121ac71 100644 --- a/src/test/utils.py +++ b/src/test/utils.py @@ -8,7 +8,7 @@ has_cuda = torch.cuda.is_available() has_flash_attn = False -has_megablocks = False +has_grouped_gemm = False try: import flash_attn # type: ignore @@ -19,10 +19,10 @@ pass try: - import megablocks # type: ignore + import grouped_gemm # type: ignore - has_megablocks = True - del megablocks + has_grouped_gemm = True + del grouped_gemm except ModuleNotFoundError: pass @@ -48,14 +48,14 @@ def requires_flash_attn(func): return func -MEGABLOCKS_MARKS = ( +GROUPED_GEMM_MARKS = ( pytest.mark.gpu, - pytest.mark.skipif(not has_megablocks, reason="Requires megablocks"), + pytest.mark.skipif(not has_grouped_gemm, reason="Requires grouped_gemm"), ) -def requires_megablocks(func): - for mark in MEGABLOCKS_MARKS: +def requires_grouped_gemm(func): + for mark in GROUPED_GEMM_MARKS: func = mark(func) return func From b0155a83656f60ded8a249aa176dbfcfc49cfa5e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 31 Jan 2025 16:24:32 -0800 Subject: [PATCH 020/230] fix init --- src/olmo_core/nn/transformer/init.py | 26 +++++++++----------------- src/test/nn/moe/moe_test.py | 6 ++++-- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/olmo_core/nn/transformer/init.py b/src/olmo_core/nn/transformer/init.py index 56b55bc6f..ff5c19678 100644 --- a/src/olmo_core/nn/transformer/init.py +++ b/src/olmo_core/nn/transformer/init.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, cast import torch import torch.nn as nn @@ -7,7 +7,7 @@ from ..attention import Attention, FusedAttention from ..feed_forward import FeedForward -from ..moe import MoE +from ..moe import MoEBase, MoELinearRouter class InitMethod(StrEnum): @@ -125,7 +125,7 @@ def init_feed_forward( def init_feed_forward_moe( self, - m: MoE, + m: MoEBase, *, d_model: int, block_idx: int, @@ -140,21 +140,13 @@ def init_feed_forward_moe( elif self == InitMethod.llama_depth: std = 0.02 / (2 * (block_idx + 1)) ** 0.5 - self._init_linear(m.inner.router.layer, std=0.02, generator=generator) + self._init_linear(cast(MoELinearRouter, m.router).w_score, std=0.02, generator=generator) nn.init.trunc_normal_( - m.inner.experts.mlp.w1, mean=0.0, std=0.02, a=-3 * std, b=3 * std, generator=generator + m.experts.mlp.w1, mean=0.0, std=0.02, a=-3 * std, b=3 * std, generator=generator ) nn.init.trunc_normal_( - m.inner.experts.mlp.w2, mean=0.0, std=std, a=-3 * std, b=3 * std, generator=generator + m.experts.mlp.w2, mean=0.0, std=std, a=-3 * std, b=3 * std, generator=generator + ) + nn.init.trunc_normal_( + m.experts.mlp.w3, mean=0.0, std=std, a=-3 * std, b=3 * std, generator=generator ) - if hasattr(m.inner.experts.mlp, "v1"): - nn.init.trunc_normal_( - m.inner.experts.mlp.v1, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - generator=generator, - ) - if (bias := getattr(m.inner.experts, "bias", None)) is not None: - nn.init.zeros_(bias) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index c8d424e0d..f25723ea1 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -15,8 +15,10 @@ def test_moe(moe_type, dtype): d_model = 128 config = MoEConfig( name=moe_type, - router=MoERouterConfig(num_experts=4, top_k=1, dtype=DType.from_pt(dtype)), - mlp=MoEMLPConfig(num_experts=8, hidden_size=512, dtype=DType.from_pt(dtype)), + num_experts=4, + hidden_size=256, + router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), + 
mlp=MoEMLPConfig(dtype=DType.from_pt(dtype)), z_loss_weight=0.1, ) moe = config.build(d_model=d_model, num_layers=1, init_device="cuda") From cacfaa88887fc34b4774c4b9b89e9bab8992f9d5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 11:07:32 -0800 Subject: [PATCH 021/230] improve how we get MoE losses --- src/olmo_core/nn/moe/moe.py | 140 +++++++++++++++--- src/olmo_core/nn/transformer/block.py | 20 ++- src/olmo_core/nn/transformer/model.py | 52 +++---- .../train/train_module/transformer.py | 82 +++++----- 4 files changed, 187 insertions(+), 107 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 9bcc131b1..8f64f92ae 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -1,6 +1,6 @@ -from abc import abstractmethod +from abc import ABCMeta, abstractmethod from dataclasses import dataclass, field -from typing import Optional, Tuple +from typing import Dict, List, Optional, Union import torch import torch.nn as nn @@ -72,6 +72,87 @@ def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> " ) from e +class MoELoss(metaclass=ABCMeta): + @abstractmethod + def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tensor): + raise NotImplementedError + + @abstractmethod + def compute( + self, total_bz: Union[int, torch.Tensor], reset: bool = True + ) -> Dict[str, torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def reset(self): + raise NotImplementedError + + +class MoELoadBalancingLoss(MoELoss): + def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int, top_k: int): + self.loss_weight = loss_weight + self.num_layers = num_layers + self.num_experts = num_experts + self.top_k = top_k + self.loss: Optional[torch.Tensor] = None + + def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tensor): + expert_scores = expert_logits.softmax(dim=-1) + loss = torch.dot(batch_size_per_expert, expert_scores) + if self.loss is None: + self.loss = loss + else: + self.loss += loss + + def compute( + self, total_bz: Union[int, torch.Tensor], reset: bool = True + ) -> Dict[str, torch.Tensor]: + if self.loss is None: + raise RuntimeError( + f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" + ) + scale = (self.num_experts * self.loss_weight) / (self.num_layers * total_bz * self.top_k) + lb_loss = scale * self.loss + if reset: + self.reset() + return {"load balancing loss": lb_loss} + + def reset(self): + self.loss = None + + +class MoERouterZLoss(MoELoss): + def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int): + self.loss_weight = loss_weight + self.num_layers = num_layers + self.num_experts = num_experts + self.loss: Optional[torch.Tensor] = None + + def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tensor): + del batch_size_per_expert + loss = torch.logsumexp(expert_logits, dim=-1).square().sum() + if self.loss is None: + self.loss = loss + else: + self.loss += loss + + def compute( + self, total_bz: Union[int, torch.Tensor], reset: bool = True + ) -> Dict[str, torch.Tensor]: + if self.loss is None: + raise RuntimeError( + f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" + ) + scale = self.loss_weight / (self.num_layers * total_bz * self.num_experts) + lb_loss = scale * self.loss + if reset: + self.reset() + return {"router Z loss": lb_loss} + + def reset(self): + self.loss = None + + class MoEBase(nn.Module): """ Base class for MoE 
implementations. @@ -102,17 +183,41 @@ def __init__( else shared_mlp.build(d_model, hidden_size, init_device=init_device) ) self.num_layers = num_layers - self.lb_loss_weight = lb_loss_weight - self.z_loss_weight = z_loss_weight + self.losses: List[MoELoss] = [] + if lb_loss_weight is not None: + self.losses.append( + MoELoadBalancingLoss( + loss_weight=lb_loss_weight, + num_layers=num_layers, + num_experts=num_experts, + top_k=self.router.top_k, + ) + ) + if z_loss_weight is not None: + self.losses.append( + MoERouterZLoss( + loss_weight=z_loss_weight, num_layers=num_layers, num_experts=num_experts + ) + ) + + def compute_losses( + self, total_bz: Union[int, torch.Tensor], reset: bool = True + ) -> Dict[str, torch.Tensor]: + out: Dict[str, torch.Tensor] = {} + for loss_fn in self.losses: + out.update(loss_fn.compute(total_bz, reset=reset)) + return out + + def reset_losses(self): + for loss_fn in self.losses: + loss_fn.reset() @abstractmethod @classmethod def _init_parallel_mlp(cls, mlp: MoEMLP) -> ParallelMLP: raise NotImplementedError - def forward( - self, x: torch.Tensor - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Run the MoE on the input ``x`` of shape ``(*, d_model)``. @@ -126,25 +231,12 @@ def forward( if self.shared_experts is not None: out = self.shared_experts(x, out, self.router.top_k) - lb_loss: Optional[torch.Tensor] = None - z_loss: Optional[torch.Tensor] = None - if self.training and (self.lb_loss_weight is not None or self.z_loss_weight is not None): + if self.training and self.losses: expert_logits = expert_logits.float() + for loss_fn in self.losses: + loss_fn.update(expert_logits, batch_size_per_expert) - # Compute load-balancing loss. - if self.lb_loss_weight is not None: - expert_scores = expert_logits.softmax(dim=-1) - total_bz = expert_scores.shape[0] - scale = (self.router.num_experts * self.lb_loss_weight) / ( - self.num_layers * total_bz * self.router.top_k - ) - lb_loss = scale * torch.dot(batch_size_per_expert, expert_scores) - - # Compute router Z-loss. - if self.z_loss_weight is not None: - z_loss = torch.logsumexp(expert_logits, dim=-1).square().mean() * self.z_loss_weight - - return out, lb_loss, z_loss + return out def apply_ep(self, ep_mesh: DeviceMesh): """ diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 7a4634fac..000ed103e 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -1,7 +1,7 @@ import math from abc import abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -407,12 +407,20 @@ def __init__( self.feed_forward_norm = layer_norm.build(d_model, init_device=init_device) self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + def compute_losses( + self, total_bz: Union[int, torch.Tensor], reset: bool = True + ) -> Dict[str, torch.Tensor]: + return self.feed_forward_moe.compute_losses(total_bz, reset=reset) + + def reset_losses(self): + self.feed_forward_moe.reset_losses() + def forward( self, x: torch.Tensor, max_doc_len: Optional[int] = None, cu_doc_lens: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> torch.Tensor: """ Run the block on the input ``x``. 
@@ -421,8 +429,7 @@ def forward( h = x + self.dropout( self.attention(self.attention_norm(x), max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens) ) - moe_out, lb_loss, z_loss = self.feed_forward_moe(self.feed_forward_norm(h)) - return h + self.dropout(moe_out), lb_loss, z_loss + return h + self.dropout(self.feed_forward_moe(self.feed_forward_norm(h))) def apply_ep(self, ep_mesh: DeviceMesh): self.feed_forward_moe.apply_ep(ep_mesh) @@ -450,9 +457,8 @@ def forward( x: torch.Tensor, max_doc_len: Optional[int] = None, cu_doc_lens: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> torch.Tensor: h = x + self.dropout( self.attention_norm(self.attention(x, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens)) ) - moe_out, lb_loss, z_loss = self.feed_forward_moe(h) - return h + self.dropout(self.feed_forward_norm(moe_out)), lb_loss, z_loss + return h + self.dropout(self.feed_forward_norm(self.feed_forward_moe(h))) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index b14797bbd..8f61edc1d 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -1,6 +1,6 @@ import logging from functools import cached_property -from typing import List, Optional, Sequence, Tuple, cast +from typing import Dict, List, Optional, Sequence, Union, cast import torch import torch.nn as nn @@ -616,24 +616,30 @@ def _validate_block(self, block: TransformerBlockBase): f"'{self.__class__.__name__}' requires a '{MoETransformerBlock.__name__}' block" ) + def compute_losses( + self, total_bz: Union[int, torch.Tensor], reset: bool = True + ) -> Dict[str, torch.Tensor]: + out: Dict[str, torch.Tensor] = {} + for block in self.blocks.values(): + for loss_name, loss_val in ( + cast(MoETransformerBlock, block).compute_losses(total_bz, reset=reset).items() + ): + if loss_name in out: + out[loss_name] += loss_val + else: + out[loss_name] = loss_val + return out + + def reset_losses(self): + for block in self.blocks.values(): + cast(MoETransformer, block).reset_losses() + def forward( self, input_ids: torch.Tensor, doc_lens: Optional[torch.Tensor] = None, max_doc_lens: Optional[Sequence[int]] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Run the transformer on the token input IDs. - - :param input_ids: The token input IDs, shape ``(batch_size, seq_len)``. - :param doc_lens: Document lengths to use in attention for intra-document masking. - Shape ``(batch_size, max_docs)``. - Required together with ``max_doc_lens`` when using intra-document masking. - :param max_doc_lens: Maximum document length for each instance in the batch. - Required together with ``doc_lens`` when using intra-document masking. - - :returns: The output logits, the optional load-balancing loss, and the optional router Z-loss. 
- """ + ) -> torch.Tensor: max_doc_len: Optional[int] = None cu_doc_lens: Optional[torch.Tensor] = None if doc_lens is not None and max_doc_lens is not None: @@ -643,24 +649,10 @@ def forward( # passthrough for non-existent layers, allows easy pipeline parallel configuration h = self.embeddings(input_ids) if self.embeddings is not None else input_ids - lb_losses: List[torch.Tensor] = [] - z_losses: List[torch.Tensor] = [] for block in self.blocks.values(): - h, lb_loss, z_loss = block(h, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens) - if lb_loss is not None: - lb_losses.append(lb_loss) - if z_loss is not None: - z_losses.append(z_loss) - - lb_loss = None - if lb_losses: - lb_loss = torch.stack(lb_losses).sum() / self.n_layers - - z_loss = None - if z_losses: - z_loss = torch.stack(z_losses).sum() / self.n_layers + h = block(h, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens) - return self.lm_head(h) if self.lm_head is not None else h, lb_loss, z_loss + return self.lm_head(h) if self.lm_head is not None else h def apply_ep(self, ep_mesh: DeviceMesh): for block in self.blocks.values(): diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 7916de311..6ef045a30 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -34,6 +34,7 @@ fused_cross_entropy_loss, ) from olmo_core.nn.transformer import ( + MoETransformer, NormalizedTransformer, Transformer, TransformerActivationCheckpointingMode, @@ -505,8 +506,9 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): z_batch_loss: Optional[torch.Tensor] = None if self.z_loss_multiplier is not None: z_batch_loss = move_to_device(torch.tensor(0.0), self.device) - moe_batch_lb_loss: Optional[torch.Tensor] = None - moe_batch_z_loss: Optional[torch.Tensor] = None + moe_batch_losses: Optional[Dict[str, torch.Tensor]] = None + if self.model.is_moe: + moe_batch_losses = {} # Split into micro-batches. if self.rank_microbatch_size < (seq_len := batch["input_ids"].shape[1]): @@ -520,32 +522,34 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): for micro_batch_idx, micro_batch in enumerate(micro_batches): with self._train_microbatch_context(micro_batch_idx, num_micro_batches): # Run forward pass. - logits, moe_lb_loss, moe_z_loss = self.model_forward(micro_batch) + logits = self.model_forward(micro_batch) + + # Get loss to optimize for, and the separate detached CE and Z loss values. loss, ce_loss, z_loss = self.loss_fn( logits, micro_batch["labels"], batch_num_tokens_for_loss ) del logits + # Update total batch CE and Z loss. ce_batch_loss += ce_loss if z_batch_loss is not None: assert z_loss is not None z_batch_loss += z_loss - # Maybe add MoE losses. - if moe_lb_loss is not None: - loss += moe_lb_loss - moe_lb_loss = get_local_tensor(moe_lb_loss.detach()) - if moe_batch_lb_loss is None: - moe_batch_lb_loss = moe_lb_loss - else: - moe_batch_lb_loss += moe_lb_loss - if moe_z_loss is not None: - loss += moe_z_loss - moe_z_loss = get_local_tensor(moe_z_loss.detach()) - if moe_batch_z_loss is None: - moe_batch_z_loss = moe_z_loss - else: - moe_batch_z_loss += moe_z_loss + # Optionally get MoE losses and update the total batch MoE losses. 
+ if self.model.is_moe: + assert moe_batch_losses is not None + moe_losses = cast(MoETransformer, self.model).compute_losses( + batch_num_tokens_for_loss, reset=True + ) + for loss_name, loss_val in moe_losses.items(): + loss += loss_val + loss_val = get_local_tensor(loss_val.detach()) + if loss_name in moe_batch_losses: + moe_batch_losses[loss_name] += loss_val + else: + moe_batch_losses[loss_name] = loss_val + del moe_losses # Run backward pass. loss.backward() @@ -565,20 +569,15 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): ReduceType.mean, namespace="train", ) - if moe_batch_lb_loss is not None: - self.record_metric( - "load balancing loss", - moe_batch_lb_loss, - ReduceType.mean, - namespace="train", - ) - if moe_batch_z_loss is not None: - self.record_metric( - "router Z loss", - moe_batch_z_loss, - ReduceType.mean, - namespace="train", - ) + if self.model.is_moe: + assert moe_batch_losses is not None + for loss_name, loss_val in moe_batch_losses.items(): + self.record_metric( + loss_name, + loss_val, + ReduceType.mean, + namespace="train", + ) if isinstance(self.optim, SkipStepOptimizer): self.optim.latest_loss = ce_batch_loss @@ -591,7 +590,7 @@ def eval_batch( self.model.eval() with torch.no_grad(): - logits, _, _ = self.model_forward(batch) + logits = self.model_forward(batch) loss: Optional[torch.Tensor] = None if labels is not None: loss = self.eval_loss_fn(logits, labels) @@ -669,11 +668,9 @@ def optim_step(self): def zero_grads(self): self.optim.zero_grad(set_to_none=True) - def model_forward( - self, batch: Dict[str, Any] - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + def model_forward(self, batch: Dict[str, Any]) -> torch.Tensor: """ - Run a forward pass on a micro-batch, returning the logits and potentially MoE losses. + Run a forward pass on a micro-batch, returning the logits. """ with self._model_forward_context(): # NOTE: Input sizes might be dynamic, e.g. when training with variable sequence lengths @@ -687,7 +684,7 @@ def model_forward( # Run model forward, get logits. 
# shape: (batch_size, seq_len, vocab_size) - output = self.model( + logits = self.model( input_ids=batch["input_ids"], # attention_mask=micro_batch.get("attention_mask"), # attention_bias=micro_batch.get("attention_bias"), @@ -695,14 +692,7 @@ def model_forward( max_doc_lens=batch.get("max_doc_lens"), ) - moe_lb_loss: Optional[torch.Tensor] = None - moe_z_loss: Optional[torch.Tensor] = None - if self.model.is_moe: - logits, moe_lb_loss, moe_z_loss = output - else: - logits = output - - return logits, moe_lb_loss, moe_z_loss + return logits def num_flops_per_token(self, seq_len: int) -> int: return self.model.num_flops_per_token(seq_len) From 055af0d39a96544d96243fc18775ef70154a8b2f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 14:44:01 -0800 Subject: [PATCH 022/230] fixes --- src/olmo_core/nn/moe/moe.py | 24 ++++++++++++------- src/olmo_core/nn/moe/parallel_mlp.py | 3 +++ .../train/train_module/transformer.py | 4 +++- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 8f64f92ae..82c7bdf88 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -74,12 +74,12 @@ def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> " class MoELoss(metaclass=ABCMeta): @abstractmethod - def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tensor): + def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): raise NotImplementedError @abstractmethod def compute( - self, total_bz: Union[int, torch.Tensor], reset: bool = True + self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs ) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -96,8 +96,12 @@ def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int, top self.top_k = top_k self.loss: Optional[torch.Tensor] = None - def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tensor): + def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): + del kwargs + # shape: (N, num_experts) expert_scores = expert_logits.softmax(dim=-1) + # shape: (num_experts,) + expert_scores = expert_scores.mean(dim=0) loss = torch.dot(batch_size_per_expert, expert_scores) if self.loss is None: self.loss = loss @@ -105,8 +109,9 @@ def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tenso self.loss += loss def compute( - self, total_bz: Union[int, torch.Tensor], reset: bool = True + self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs ) -> Dict[str, torch.Tensor]: + del kwargs if self.loss is None: raise RuntimeError( f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" @@ -128,8 +133,8 @@ def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int): self.num_experts = num_experts self.loss: Optional[torch.Tensor] = None - def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tensor): - del batch_size_per_expert + def update(self, expert_logits: torch.Tensor, **kwargs): + del kwargs loss = torch.logsumexp(expert_logits, dim=-1).square().sum() if self.loss is None: self.loss = loss @@ -137,13 +142,14 @@ def update(self, expert_logits: torch.Tensor, batch_size_per_expert: torch.Tenso self.loss += loss def compute( - self, total_bz: Union[int, torch.Tensor], reset: bool = True + self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs ) -> Dict[str, torch.Tensor]: + del kwargs if 
self.loss is None: raise RuntimeError( f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" ) - scale = self.loss_weight / (self.num_layers * total_bz * self.num_experts) + scale = self.loss_weight / (self.num_layers * total_bz) lb_loss = scale * self.loss if reset: self.reset() @@ -234,7 +240,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training and self.losses: expert_logits = expert_logits.float() for loss_fn in self.losses: - loss_fn.update(expert_logits, batch_size_per_expert) + loss_fn.update(expert_logits, batch_size_per_expert=batch_size_per_expert) return out diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index ef6c7204c..4c7a78bc0 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -68,6 +68,9 @@ def forward( :param x: The input of shape ``(*, d_model)``. :param expert_weights: Expert weights of shape ``(N, top_k)``. :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. + + :returns: The output with the same shape as ``x`` and a tensor with shape ``(num_experts,)`` + containing the number of items/tokens routed to each expert. """ del x, expert_weights, expert_indices raise NotImplementedError diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 6ef045a30..7d570c66b 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -501,7 +501,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Calculate how many tokens are going to be used in the loss. batch_num_tokens_for_loss = (batch["labels"] != self.label_ignore_index).sum() - # Batch losses. + # Batch losses to record. ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) z_batch_loss: Optional[torch.Tensor] = None if self.z_loss_multiplier is not None: @@ -532,9 +532,11 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Update total batch CE and Z loss. ce_batch_loss += ce_loss + del ce_loss if z_batch_loss is not None: assert z_loss is not None z_batch_loss += z_loss + del z_loss # Optionally get MoE losses and update the total batch MoE losses. if self.model.is_moe: From 0f9b5189d8187a893ffee3e5c2ba32a0466ef273 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 14:47:15 -0800 Subject: [PATCH 023/230] clean up --- src/olmo_core/nn/moe/moe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 82c7bdf88..99c76328b 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -89,6 +89,10 @@ def reset(self): class MoELoadBalancingLoss(MoELoss): + """ + Implements the load balancing loss from Switch Transformers. 
+ """ + def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int, top_k: int): self.loss_weight = loss_weight self.num_layers = num_layers From d258b42f8123250772c9bb20a6e20641db48ffd5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 14:51:30 -0800 Subject: [PATCH 024/230] fix --- src/olmo_core/nn/moe/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 99c76328b..38a44971b 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -222,8 +222,8 @@ def reset_losses(self): for loss_fn in self.losses: loss_fn.reset() - @abstractmethod @classmethod + @abstractmethod def _init_parallel_mlp(cls, mlp: MoEMLP) -> ParallelMLP: raise NotImplementedError From 02fa200f7c38b00161725fbb096c343f390aa38d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:00:53 -0800 Subject: [PATCH 025/230] Add router test --- src/olmo_core/nn/moe/ops.py | 11 ++++++++--- src/test/nn/moe/router_test.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 src/test/nn/moe/router_test.py diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py index 5922f5790..dc84eb51c 100644 --- a/src/olmo_core/nn/moe/ops.py +++ b/src/olmo_core/nn/moe/ops.py @@ -4,8 +4,6 @@ import torch import torch.distributed as dist -from . import kernels - def _is_eligible(x): return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) @@ -61,6 +59,8 @@ def forward( bins: torch.Tensor, top_k: int, ): + from . import kernels + ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k return kernels.gather(x, indices, bin_ids, None, bins, top_k) @@ -68,8 +68,9 @@ def forward( @staticmethod @autocast_bwd def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() + from . import kernels + grad = grad.contiguous() indices, bin_ids, bins = ctx.saved_tensors out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) return out, None, None, None, None, None @@ -97,6 +98,8 @@ def forward( bins: torch.Tensor, top_k: int, ) -> torch.Tensor: + from . import kernels + maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k @@ -106,6 +109,8 @@ def forward( @staticmethod @autocast_bwd def backward(ctx: Any, grad: torch.Tensor): + from . 
import kernels + grad = grad.contiguous() saved_tensors = ctx.saved_tensors diff --git a/src/test/nn/moe/router_test.py b/src/test/nn/moe/router_test.py new file mode 100644 index 000000000..81f271225 --- /dev/null +++ b/src/test/nn/moe/router_test.py @@ -0,0 +1,27 @@ +import pytest +import torch + +from olmo_core.nn.moe.router import MoELinearRouter + +from ...utils import DEVICES + + +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize( + "uniform_expert_assignment", + [pytest.param(True, id="uniform"), pytest.param(False, id="computed")], +) +def test_router(device: torch.device, uniform_expert_assignment: bool): + router = MoELinearRouter( + d_model=128, + num_experts=4, + jitter_eps=0.1, + top_k=2, + normalize_expert_weights=True, + uniform_expert_assignment=uniform_expert_assignment, + ).to(device) + x = torch.randn((2, 4, 128), device=device) + logits, weights, indices = router(x) + assert logits.shape == (8, 4) + assert weights.shape == (8, 2) + assert indices.shape == (8, 2) From 31a972f04be03e42d00f135afddd07e98ba83852 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:03:40 -0800 Subject: [PATCH 026/230] fix config --- src/olmo_core/nn/moe/moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 38a44971b..5b7e81338 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -55,7 +55,6 @@ def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> " kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") kwargs.update( - dtype=kwargs.pop("dtype").as_pt(), d_model=d_model, num_layers=num_layers, init_device=init_device, From d26755dbe6bd3575bc2078202a588609c9d41a2c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:23:06 -0800 Subject: [PATCH 027/230] fixes --- src/olmo_core/nn/moe/parallel_mlp.py | 5 +++-- src/test/nn/moe/moe_test.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 4c7a78bc0..b9d5e63b3 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -330,9 +330,10 @@ def indices_and_bins( # Calculate the bin bounds for the sorted items/tokens. # shape: (num_experts,) - bins = torch.cumsum(batch_size_per_expert, 0) + bins = torch.empty_like(batch_size_per_expert) + torch.cumsum(batch_size_per_expert, 0, out=bins) - return indices, bin_ids, bins, batch_size_per_expert + return indices.int(), bin_ids, bins, batch_size_per_expert def permute_and_compute( self, diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index f25723ea1..f066c2723 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -34,11 +34,15 @@ def test_moe(moe_type, dtype): assert config.num_params(d_model) == num_params # Run forward pass. - x = torch.randn(2, 16, d_model, dtype=dtype, device="cuda", requires_grad=True) - output, lb_loss, z_loss = moe(x) + B, S = 2, 16 + x = torch.randn(B, S, d_model, dtype=dtype, device="cuda", requires_grad=True) + + output = moe(x) assert output.shape == x.shape - assert lb_loss is not None - assert z_loss is not None + + losses = moe.compute_losses(B * S) + lb_loss = losses["load balancing loss"] + z_loss = losses["router Z loss"] loss = lb_loss + z_loss # Run backward pass. 
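
For reference, the two auxiliary losses tracked by `MoELoadBalancingLoss` and `MoERouterZLoss` above reduce to a few lines of tensor math. The following is a minimal, self-contained sketch (plain PyTorch, no accumulation across layers or micro-batches, no expert parallelism); the function name and default weights are illustrative, but the scaling mirrors the `compute()` methods in these patches:

import torch

def moe_aux_losses(
    logits: torch.Tensor,   # (N, num_experts) float router logits
    counts: torch.Tensor,   # (num_experts,) tokens routed to each expert
    *,
    top_k: int,
    num_layers: int,
    lb_loss_weight: float = 0.01,
    z_loss_weight: float = 0.001,
):
    N, num_experts = logits.shape
    # Load balancing loss: per-expert token counts dotted with the mean
    # routing probability per expert, then scaled.
    probs = logits.softmax(dim=-1).mean(dim=0)
    lb = torch.dot(counts.type_as(probs), probs)
    lb = lb * (num_experts * lb_loss_weight) / (num_layers * N * top_k)
    # Router Z-loss: squared logsumexp of the logits, then scaled.
    z = torch.logsumexp(logits, dim=-1).square().sum()
    z = z * z_loss_weight / (num_layers * N)
    return {"load balancing loss": lb, "router Z loss": z}

# Example with random data: 32 tokens, 4 experts, top-1 routing, single layer.
logits = torch.randn(32, 4)
counts = torch.histc(logits.argmax(dim=-1).float(), bins=4, min=0, max=3)
print(moe_aux_losses(logits, counts, top_k=1, num_layers=1))
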
From dd910901cc7120045c90c63a763264d01ef33171 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:26:43 -0800 Subject: [PATCH 028/230] fix? --- src/olmo_core/nn/moe/parallel_mlp.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index b9d5e63b3..cff95ee48 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -313,7 +313,13 @@ def indices_and_bins( """ :param expert_indices: A 1D tensor. """ - # shape: (N,) + # Histogram the expert ids to identify the number of + # items/tokens routed to each expert. + # shape: (num_experts,), LongTensor + batch_size_per_expert = torch.histc( + expert_indices, bins=self.num_experts, min=0, max=self.num_experts - 1 + ) + expert_indices = expert_indices.int() # Sort the expert ids to produce the scatter/gather @@ -321,16 +327,9 @@ def indices_and_bins( # shape: (N,), (N,) bin_ids, indices = torch.sort(expert_indices) - # Histogram the expert ids to identify the number of - # items/tokens routed to each expert. - # shape: (num_experts,) - batch_size_per_expert = torch.histc( - expert_indices, bins=self.num_experts, min=0, max=self.num_experts - 1 - ) - # Calculate the bin bounds for the sorted items/tokens. # shape: (num_experts,) - bins = torch.empty_like(batch_size_per_expert) + bins = torch.empty_like(batch_size_per_expert, dtype=torch.int32) torch.cumsum(batch_size_per_expert, 0, out=bins) return indices.int(), bin_ids, bins, batch_size_per_expert From 05026cd91c825c07a945a3ed3344723b7df47107 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:28:48 -0800 Subject: [PATCH 029/230] fix loss --- src/olmo_core/nn/moe/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 5b7e81338..37c963cfb 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -105,7 +105,7 @@ def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Te expert_scores = expert_logits.softmax(dim=-1) # shape: (num_experts,) expert_scores = expert_scores.mean(dim=0) - loss = torch.dot(batch_size_per_expert, expert_scores) + loss = torch.dot(batch_size_per_expert.type_as(expert_scores), expert_scores) if self.loss is None: self.loss = loss else: From 925bf24dd669fb49a7af3f7aae7f5cc53d709277 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:34:55 -0800 Subject: [PATCH 030/230] Add test with expert parallelism --- src/test/nn/moe/moe_test.py | 50 +++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index f066c2723..d1e97510f 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -1,9 +1,13 @@ import pytest import torch +import torch.distributed as dist +from torch.distributed.tensor import init_device_mesh from olmo_core.config import DType from olmo_core.nn.moe import MoEConfig, MoEMLPConfig, MoERouterConfig, MoEType +from olmo_core.utils import get_default_device +from ...distributed.utils import requires_multi_gpu, run_distributed_test from ...utils import requires_gpu, requires_grouped_gemm @@ -48,3 +52,49 @@ def test_moe(moe_type, dtype): # Run backward pass. 
loss.backward() assert x.grad is not None + + +def run_moe_with_expert_parallelism(moe_type, dtype): + ep_mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) + + d_model = 128 + config = MoEConfig( + name=moe_type, + num_experts=4, + hidden_size=256, + router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), + mlp=MoEMLPConfig(dtype=DType.from_pt(dtype)), + z_loss_weight=0.1, + ) + moe = config.build(d_model=d_model, num_layers=1, init_device="meta") + moe.apply_ep(ep_mesh) + moe.to_empty(device=get_default_device()) + + # Run forward pass. + B, S = 2, 16 + x = torch.randn(B, S, d_model, dtype=dtype, device="cuda", requires_grad=True) + + output = moe(x) + assert output.shape == x.shape + + losses = moe.compute_losses(B * S) + lb_loss = losses["load balancing loss"] + z_loss = losses["router Z loss"] + loss = lb_loss + z_loss + + # Run backward pass. + loss.backward() + assert x.grad is not None + + +@requires_multi_gpu +@requires_grouped_gemm +@pytest.mark.parametrize("moe_type", [MoEType.dropless]) +@pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) +def test_moe_with_expert_parallelism(moe_type, dtype): + run_distributed_test( + run_moe_with_expert_parallelism, + backend="nccl", + start_method="spawn", + func_args=(moe_type, dtype), + ) From 38b674a7cacf2c71822713895058adecfff63656 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:38:50 -0800 Subject: [PATCH 031/230] lol, fix --- src/olmo_core/nn/moe/ops.py | 2 +- src/olmo_core/nn/moe/parallel_mlp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py index dc84eb51c..93275d7f3 100644 --- a/src/olmo_core/nn/moe/ops.py +++ b/src/olmo_core/nn/moe/ops.py @@ -211,7 +211,7 @@ def all_to_all( ) -def sum(x: torch.Tensor, dim: int = 0) -> torch.Tensor: +def sum_tensor(x: torch.Tensor, dim: int = 0) -> torch.Tensor: if x.shape[dim] == 1: return x.squeeze(dim=dim) return x.sum(dim=dim) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index cff95ee48..37bdb30aa 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -300,7 +300,7 @@ def parallel_forward_once( # Reduce along the hidden sharding to get the final outputs. # TODO: Fuse this into the following local permutation? - x = ops.sum(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) + x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) # Un-permute locally to setup for the next series of operations. x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) From daff43791c5e260785b58fb1b017a1d78ce89e38 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:47:46 -0800 Subject: [PATCH 032/230] fix dtypes? --- src/olmo_core/nn/moe/parallel_mlp.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 37bdb30aa..c1e7ada0f 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -270,12 +270,10 @@ def parallel_forward_once( # Calculate the bins boundaries from the token counts. 
parallel_tokens_per_expert = parallel_tokens_per_expert.sum( dim=0, - dtype=torch.int, - ) - parallel_bins = torch.cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins + dtype=torch.long, ) + parallel_bins = torch.empty_like(parallel_tokens_per_expert, dtype=torch.int32) + torch.cumsum(parallel_tokens_per_expert, 0, out=parallel_bins) # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. @@ -283,7 +281,7 @@ def parallel_forward_once( parallel_x = self.permute_and_compute( parallel_x, parallel_tokens_per_expert, - parallel_indices, + parallel_indices.int(), parallel_bin_ids, None, # expert_weights parallel_bins, From dfde49dc2cd8c482f06206b98c641e108701e766 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:57:16 -0800 Subject: [PATCH 033/230] fix some typos --- src/olmo_core/nn/moe/parallel_mlp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index c1e7ada0f..dd5cd6533 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -155,15 +155,13 @@ def parallel_forward_once( # its set of experts in its local HBM. # # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens + # expert assignment. After the distributed permutation the tokens # are grouped by which device they came from. We re-order them # locally to allow for efficient computation. # # After this series of permutations we compute the linear layers # and then repeat these three steps in reverse to produce the final # output. - # - # Compute the mapping of local tokens to experts. top_k = expert_weights.shape[-1] From b8bee1eedadb67a3e5cecb8c2428d561fec8bd6c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 15:59:33 -0800 Subject: [PATCH 034/230] check that loss is finite --- src/test/nn/moe/moe_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index d1e97510f..b99cd36e7 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -1,3 +1,5 @@ +import math + import pytest import torch import torch.distributed as dist @@ -46,7 +48,9 @@ def test_moe(moe_type, dtype): losses = moe.compute_losses(B * S) lb_loss = losses["load balancing loss"] + assert math.isfinite(lb_loss.item()) z_loss = losses["router Z loss"] + assert math.isfinite(z_loss.item()) loss = lb_loss + z_loss # Run backward pass. @@ -79,7 +83,9 @@ def run_moe_with_expert_parallelism(moe_type, dtype): losses = moe.compute_losses(B * S) lb_loss = losses["load balancing loss"] + assert math.isfinite(lb_loss.item()) z_loss = losses["router Z loss"] + assert math.isfinite(z_loss.item()) loss = lb_loss + z_loss # Run backward pass. 
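
The bookkeeping that `ParallelMLP.indices_and_bins` settles on in the patches above (histogram of expert assignments, sort by expert id, cumulative bin offsets) can be illustrated standalone. This sketch keeps default integer dtypes for clarity; the real implementation casts to `int32` for the CUDA kernels and runs on GPU:

import torch

def indices_and_bins(expert_indices: torch.Tensor, num_experts: int):
    # expert_indices: flat (N * top_k,) tensor of expert assignments.
    # 1) Count how many items were routed to each expert.
    counts = torch.histc(
        expert_indices.float(), bins=num_experts, min=0, max=num_experts - 1
    ).long()
    # 2) Sort by expert id: `indices` gathers items grouped by expert,
    #    `bin_ids` is the expert id of each sorted item.
    bin_ids, indices = torch.sort(expert_indices)
    # 3) Cumulative counts give the end offset of each expert's bin.
    bins = torch.cumsum(counts, 0)
    return indices, bin_ids, bins, counts

expert_indices = torch.tensor([2, 0, 1, 1, 3, 0])
indices, bin_ids, bins, counts = indices_and_bins(expert_indices, num_experts=4)
# counts  -> [2, 2, 1, 1]
# bins    -> [2, 4, 5, 6]
# bin_ids -> [0, 0, 1, 1, 2, 3]
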
From 1f869ded2546333d4a9af1b897a88398db95098f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 16:17:26 -0800 Subject: [PATCH 035/230] compute active params --- README.md | 2 +- src/olmo_core/internal/experiment.py | 4 ++-- src/olmo_core/nn/moe/mlp.py | 3 +++ src/olmo_core/nn/moe/moe.py | 7 +++++++ src/olmo_core/nn/transformer/config.py | 21 +++++++++++++++++++++ 5 files changed, 34 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4f0749878..926a2a764 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ pip install ai2-olmo-core There are a number of optional dependencies that must be installed to use certain functionality as well, including: - [flash-attn](https://github.com/Dao-AILab/flash-attention) for flash attention and certain other fused operations. - [torchao](https://github.com/pytorch/ao) for float8 training. -- [grouped_gemm](https://github.com/tgale96/grouped_gemm) for mixture-of-experts (MoE) models. +- [grouped_gemm](https://github.com/tgale96/grouped_gemm) for mixture-of-experts (MoE) models. You may need to compile from source until [PR #21](https://github.com/tgale96/grouped_gemm/pull/21) is released (post v0.1.6). The published [Docker images](https://github.com/orgs/allenai/packages?repo_name=OLMo-core) contain all core and optional dependencies, and are regularly tested on our in-house H100 clusters. But there are several things to keep in mind if you intend to use these images: diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index b99a95559..643aea9b6 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -90,8 +90,8 @@ def run(self, config: ExperimentConfig): print(config) print( "\n" - f"[b blue]Total parameters:[/] {config.model.num_params:,d}\n" - f"[b blue]Non-embedding parameters:[/] {config.model.num_non_embedding_params:,d}" + f"[b blue]Total parameters:[/] {config.model.num_params:,d} ({config.model.num_active_params:,d} active)\n" + f"[b blue]Non-embedding parameters:[/] {config.model.num_non_embedding_params:,d} ({config.model.num_active_non_embedding_params:,d} active)" ) if self == SubCmd.launch: diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 42bfe40a1..59ee5e205 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -67,6 +67,9 @@ def num_params(self, d_model: int, num_experts: int, hidden_size: int) -> int: return num_params + def num_active_params(self, d_model: int, top_k: int, hidden_size: int) -> int: + return self.num_params(d_model, top_k, hidden_size) + def build( self, d_model: int, num_experts: int, hidden_size: int, *, init_device: str = "cpu" ) -> "MoEMLP": diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 37c963cfb..6c43fe7a1 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -51,6 +51,13 @@ def num_params(self, d_model: int) -> int: return num_params + def num_active_params(self, d_model: int) -> int: + return ( + self.num_params(d_model) + - self.mlp.num_params(d_model, self.num_experts, self.hidden_size) + + self.mlp.num_active_params(d_model, self.router.top_k, self.hidden_size) + ) + def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> "MoEBase": kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 090fcecfd..2d2781d39 100644 --- a/src/olmo_core/nn/transformer/config.py 
+++ b/src/olmo_core/nn/transformer/config.py @@ -158,6 +158,20 @@ def num_params(self) -> int: return num_params + @property + def num_active_params(self) -> int: + """ + The total number of active parameters that a model from this config would have. + """ + num_params = self.num_params + if self.block.feed_forward_moe is None: + return num_params + diff_per_block = self.block.feed_forward_moe.num_params( + self.d_model + ) - self.block.feed_forward_moe.num_active_params(self.d_model) + total_diff = self.n_layers * diff_per_block + return num_params - total_diff + @property def num_non_embedding_params(self) -> int: """ @@ -165,6 +179,13 @@ def num_non_embedding_params(self) -> int: """ return self.num_params - self.d_model * self.vocab_size + @property + def num_active_non_embedding_params(self) -> int: + """ + The number of active parameters excluding embedding parameters. + """ + return self.num_active_params - self.d_model * self.vocab_size + def num_flops_per_token(self, seq_len: int) -> int: """ Get the approximate number of flops per token. From 41ccfe40dc066b676bb1b63e96ccebe2c6b43a99 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 1 Feb 2025 16:30:21 -0800 Subject: [PATCH 036/230] Allow expert parallelism --- .../distributed/parallel/__init__.py | 47 +++++++++++++++++-- .../distributed/parallel/expert_parallel.py | 15 ++++++ .../train/train_module/transformer.py | 22 ++++++++- 3 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 src/olmo_core/distributed/parallel/expert_parallel.py diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index 904e874fe..2ce79b11d 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -9,6 +9,7 @@ from olmo_core.utils import get_default_device from .data_parallel import DataParallelConfig, DataParallelType, DPMeshDimName +from .expert_parallel import ExpertParallelConfig from .pipeline_parallel import ( PipelineParallelConfig, PipelineSchedule, @@ -21,12 +22,14 @@ "MeshDimName", "get_dp_mesh", "get_tp_mesh", + "get_ep_mesh", "get_pp_mesh", "get_dp_process_group", "DataParallelType", "DataParallelConfig", "DPMeshDimName", "TensorParallelConfig", + "ExpertParallelConfig", "PipelineParallelConfig", "PipelineScheduleType", "PipelineSchedule", @@ -55,6 +58,11 @@ class MeshDimName(StrEnum): The DP dimension over which the model is sharded. """ + ep = "ep" + """ + Expert parallel (EP). + """ + tp = "tp" """ Tensor parallel (TP). @@ -69,6 +77,7 @@ class MeshDimName(StrEnum): def build_device_mesh( *, dp: Optional[DataParallelConfig] = None, + ep: Optional[ExpertParallelConfig] = None, tp: Optional[TensorParallelConfig] = None, pp: Optional[PipelineParallelConfig] = None, device_type: Optional[str] = None, @@ -78,13 +87,13 @@ def build_device_mesh( The resulting dimension names will be defined in :class:`MeshDimName`. .. important:: - A data parallel config is required if either a pipeline or tensor parallel config is set. + A data parallel config is required if any other parallel config is set. 
""" - if pp is None and tp is None and dp is None: + if ep is None and pp is None and tp is None and dp is None: return None if dp is None: raise OLMoConfigurationError( - "Data parallel config is required in addition to tensor/pipeline parallel configs" + "Data parallel config is required in addition to expert/tensor/pipeline parallel configs" ) device_type = device_type or get_default_device().type @@ -103,6 +112,12 @@ def build_device_mesh( f"{tp.__class__.__name__}.degree must be at least 1 and divide into the world size" ) dp_world_size //= tp.degree + if ep is not None: + if ep.degree < 1 or dp_world_size % ep.degree != 0: + raise OLMoConfigurationError( + f"{ep.__class__.__name__}.degree must be at least 1 and divide into the world size" + ) + dp_world_size //= ep.degree # Build up mesh dimensions. names: List[str] = [] @@ -129,10 +144,13 @@ def build_device_mesh( names.append(MeshDimName.dp) dims.append(dp_world_size) - # And lastly tensor parallel. + # And lastly tensor/expert parallel. if tp is not None: names.append(MeshDimName.tp) dims.append(tp.degree) + if ep is not None: + names.append(MeshDimName.ep) + dims.append(ep.degree) log.info(f"Building {len(dims)}-D device mesh with dimensions:") for i, (name, dim) in enumerate(zip(names, dims)): @@ -225,6 +243,27 @@ def get_tp_mesh( return None +def get_ep_mesh( + device_mesh: Optional[DeviceMesh] = None, *, dim_name: str = MeshDimName.ep +) -> Optional[DeviceMesh]: + """ + Get the expert parallel sub-mesh associated with a ``DeviceMesh`` that was potentially + created from :func:`build_device_mesh()`. + + :param dim_name: The name of the target mesh dimension. + """ + if device_mesh is None: + return None + + if device_mesh.mesh_dim_names is None: + raise RuntimeError("could not determine expert parallel sub-mesh without dimension names") + + if dim_name in device_mesh.mesh_dim_names: + return device_mesh[dim_name] + else: + return None + + def get_pp_mesh( device_mesh: Optional[DeviceMesh] = None, *, dim_name: str = MeshDimName.pp ) -> Optional[DeviceMesh]: diff --git a/src/olmo_core/distributed/parallel/expert_parallel.py b/src/olmo_core/distributed/parallel/expert_parallel.py new file mode 100644 index 000000000..2f27786db --- /dev/null +++ b/src/olmo_core/distributed/parallel/expert_parallel.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + +from olmo_core.config import Config + + +@dataclass +class ExpertParallelConfig(Config): + """ + Configuration class for expert parallelism (EP). + """ + + degree: int + """ + The EP degree. + """ diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 7d570c66b..2f3f9b007 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -19,10 +19,12 @@ from olmo_core.distributed.parallel import ( DataParallelConfig, DataParallelType, + ExpertParallelConfig, TensorParallelConfig, build_device_mesh, get_dp_mesh, get_dp_process_group, + get_ep_mesh, get_tp_mesh, ) from olmo_core.distributed.utils import get_local_tensor, get_world_size @@ -71,6 +73,13 @@ class TransformerTensorParallelConfig(TensorParallelConfig): """ +@dataclass +class TransformerExpertParallelConfig(ExpertParallelConfig): + """ + Transformer-specific expert parallel config. 
+ """ + + @beta_feature @dataclass class TransformerActivationCheckpointingConfig(Config): @@ -135,6 +144,7 @@ class TransformerTrainModuleConfig(Config): float8_config: Optional[Float8Config] = None dp_config: Optional[TransformerDataParallelConfig] = None tp_config: Optional[TransformerTensorParallelConfig] = None + ep_config: Optional[TransformerExpertParallelConfig] = None ac_config: Optional[TransformerActivationCheckpointingConfig] = None # Loss function settings. @@ -236,6 +246,7 @@ def __init__( float8_config: Optional[Float8Config] = None, dp_config: Optional[TransformerDataParallelConfig] = None, tp_config: Optional[TransformerTensorParallelConfig] = None, + ep_config: Optional[TransformerExpertParallelConfig] = None, ac_config: Optional[TransformerActivationCheckpointingConfig] = None, compile_loss: bool = False, fused_loss: bool = False, @@ -289,7 +300,9 @@ def __init__( ) log.info("Swapped linear layers to Float8 linear layers") - # Maybe apply tensor parallelism. + # Maybe apply tensor/expert parallelism. + if tp_config is not None and ep_config is not None: + raise NotImplementedError("TP + EP is not implemented yet") if tp_config is not None: tp_mesh = get_tp_mesh(self.world_mesh) assert tp_mesh is not None @@ -302,6 +315,13 @@ def __init__( log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" ) + if ep_config is not None: + if not self.model.is_moe: + raise OLMoConfigurationError("Expert parallelism is only valid for MoE models") + ep_mesh = get_ep_mesh(self.world_mesh) + assert ep_mesh is not None + cast(MoETransformer, self.model).apply_ep(ep_mesh) + log.info("Applied expert parallelism to the model") # Maybe apply activation checkpointing. if ac_config is not None: From 70a12756efc7968081b2996e6d90de3add3651c1 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 09:53:32 -0800 Subject: [PATCH 037/230] test size of experts --- src/test/nn/moe/mlp_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index 91e7eb603..e1de29891 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -2,6 +2,7 @@ import torch.distributed as dist from torch.distributed.tensor import init_device_mesh +from olmo_core.distributed.utils import get_local_tensor from olmo_core.nn.moe.mlp import MoEMLP from olmo_core.utils import get_default_device @@ -33,6 +34,7 @@ def run_mlp_with_expert_parallelism(): ) mlp.apply_ep(ep_mesh) mlp.to_empty(device=get_default_device()) + assert get_local_tensor(mlp.w1).shape == (2, 256, 128) x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) tokens_per_expert = torch.tensor([3, 2], device="cuda") From 953049c4c1c03fc64adb662bb87c9ac3de33800a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 09:54:32 -0800 Subject: [PATCH 038/230] don't require grouped gemm for MoEMLP test --- src/test/nn/moe/mlp_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index e1de29891..c7d23ece2 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -7,11 +7,10 @@ from olmo_core.utils import get_default_device from ...distributed.utils import requires_multi_gpu, run_distributed_test -from ...utils import requires_gpu, requires_grouped_gemm +from ...utils import requires_gpu @requires_gpu -@requires_grouped_gemm def test_mlp(): mlp = MoEMLP( d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 @@ -44,6 
+43,5 @@ def run_mlp_with_expert_parallelism(): @requires_multi_gpu -@requires_grouped_gemm def test_mlp_with_expert_parallelism(): run_distributed_test(run_mlp_with_expert_parallelism, backend="nccl", start_method="spawn") From 47099406db93d289878a629b3fcf4cdcbf5bb0a2 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 09:58:12 -0800 Subject: [PATCH 039/230] move losses to their own module --- src/olmo_core/nn/moe/loss.py | 97 ++++++++++++++++++++++++++++++++++++ src/olmo_core/nn/moe/moe.py | 94 +--------------------------------- 2 files changed, 99 insertions(+), 92 deletions(-) create mode 100644 src/olmo_core/nn/moe/loss.py diff --git a/src/olmo_core/nn/moe/loss.py b/src/olmo_core/nn/moe/loss.py new file mode 100644 index 000000000..6b622b1e7 --- /dev/null +++ b/src/olmo_core/nn/moe/loss.py @@ -0,0 +1,97 @@ +from abc import ABCMeta, abstractmethod +from typing import Dict, Optional, Union + +import torch + +__all__ = ["MoELoss", "MoELoadBalancingLoss", "MoERouterZLoss"] + + +class MoELoss(metaclass=ABCMeta): + @abstractmethod + def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): + raise NotImplementedError + + @abstractmethod + def compute( + self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs + ) -> Dict[str, torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def reset(self): + raise NotImplementedError + + +class MoELoadBalancingLoss(MoELoss): + """ + Implements the load balancing loss from Switch Transformers. + """ + + def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int, top_k: int): + self.loss_weight = loss_weight + self.num_layers = num_layers + self.num_experts = num_experts + self.top_k = top_k + self.loss: Optional[torch.Tensor] = None + + def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): + del kwargs + # shape: (N, num_experts) + expert_scores = expert_logits.softmax(dim=-1) + # shape: (num_experts,) + expert_scores = expert_scores.mean(dim=0) + loss = torch.dot(batch_size_per_expert.type_as(expert_scores), expert_scores) + if self.loss is None: + self.loss = loss + else: + self.loss += loss + + def compute( + self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs + ) -> Dict[str, torch.Tensor]: + del kwargs + if self.loss is None: + raise RuntimeError( + f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" + ) + scale = (self.num_experts * self.loss_weight) / (self.num_layers * total_bz * self.top_k) + lb_loss = scale * self.loss + if reset: + self.reset() + return {"load balancing loss": lb_loss} + + def reset(self): + self.loss = None + + +class MoERouterZLoss(MoELoss): + def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int): + self.loss_weight = loss_weight + self.num_layers = num_layers + self.num_experts = num_experts + self.loss: Optional[torch.Tensor] = None + + def update(self, expert_logits: torch.Tensor, **kwargs): + del kwargs + loss = torch.logsumexp(expert_logits, dim=-1).square().sum() + if self.loss is None: + self.loss = loss + else: + self.loss += loss + + def compute( + self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs + ) -> Dict[str, torch.Tensor]: + del kwargs + if self.loss is None: + raise RuntimeError( + f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" + ) + scale = self.loss_weight / (self.num_layers * total_bz) + lb_loss = scale * self.loss + if reset: + self.reset() + return 
{"router Z loss": lb_loss} + + def reset(self): + self.loss = None diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 6c43fe7a1..47c74dcb6 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -1,4 +1,4 @@ -from abc import ABCMeta, abstractmethod +from abc import abstractmethod from dataclasses import dataclass, field from typing import Dict, List, Optional, Union @@ -8,6 +8,7 @@ from ...config import Config, StrEnum from ...exceptions import OLMoConfigurationError +from .loss import MoELoadBalancingLoss, MoELoss, MoERouterZLoss from .mlp import MoEMLP, MoEMLPConfig from .parallel_mlp import ParallelDroplessMLP, ParallelMLP from .router import MoERouterConfig @@ -78,97 +79,6 @@ def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> " ) from e -class MoELoss(metaclass=ABCMeta): - @abstractmethod - def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): - raise NotImplementedError - - @abstractmethod - def compute( - self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs - ) -> Dict[str, torch.Tensor]: - raise NotImplementedError - - @abstractmethod - def reset(self): - raise NotImplementedError - - -class MoELoadBalancingLoss(MoELoss): - """ - Implements the load balancing loss from Switch Transformers. - """ - - def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int, top_k: int): - self.loss_weight = loss_weight - self.num_layers = num_layers - self.num_experts = num_experts - self.top_k = top_k - self.loss: Optional[torch.Tensor] = None - - def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): - del kwargs - # shape: (N, num_experts) - expert_scores = expert_logits.softmax(dim=-1) - # shape: (num_experts,) - expert_scores = expert_scores.mean(dim=0) - loss = torch.dot(batch_size_per_expert.type_as(expert_scores), expert_scores) - if self.loss is None: - self.loss = loss - else: - self.loss += loss - - def compute( - self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs - ) -> Dict[str, torch.Tensor]: - del kwargs - if self.loss is None: - raise RuntimeError( - f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" - ) - scale = (self.num_experts * self.loss_weight) / (self.num_layers * total_bz * self.top_k) - lb_loss = scale * self.loss - if reset: - self.reset() - return {"load balancing loss": lb_loss} - - def reset(self): - self.loss = None - - -class MoERouterZLoss(MoELoss): - def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int): - self.loss_weight = loss_weight - self.num_layers = num_layers - self.num_experts = num_experts - self.loss: Optional[torch.Tensor] = None - - def update(self, expert_logits: torch.Tensor, **kwargs): - del kwargs - loss = torch.logsumexp(expert_logits, dim=-1).square().sum() - if self.loss is None: - self.loss = loss - else: - self.loss += loss - - def compute( - self, total_bz: Union[int, torch.Tensor], reset: bool = True, **kwargs - ) -> Dict[str, torch.Tensor]: - del kwargs - if self.loss is None: - raise RuntimeError( - f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" - ) - scale = self.loss_weight / (self.num_layers * total_bz) - lb_loss = scale * self.loss - if reset: - self.reset() - return {"router Z loss": lb_loss} - - def reset(self): - self.loss = None - - class MoEBase(nn.Module): """ Base class for MoE implementations. 
From f9f79b6bf8b938aa54af11711a63b5464fce502e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:20:37 -0800 Subject: [PATCH 040/230] remove megablocks from build --- Makefile | 3 --- src/Dockerfile | 5 +---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/Makefile b/Makefile index a0bb002f8..83dc65259 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,6 @@ TORCH_NIGHTLY_VERSION = "2.6.0.dev20241209" TORCH_NIGHTLY_VERSION_SHORT = $(shell echo $(TORCH_NIGHTLY_VERSION) | tr -d .) TORCHAO_VERSION = "0.6.1" GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" -MEGABLOCKS_VERSION = "megablocks @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl VERSION = $(shell python src/olmo_core/version.py) @@ -57,7 +56,6 @@ stable-image : --build-arg TORCH_VERSION=$(TORCH_VERSION) \ --build-arg FLASH_ATTN_WHEEL=$(FLASH_ATTN_WHEEL) \ --build-arg GROUPED_GEMM_VERSION=$(GROUPED_GEMM_VERSION) \ - --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --target stable \ --progress plain \ @@ -73,7 +71,6 @@ nightly-image : --build-arg TORCH_VERSION=$(TORCH_VERSION) \ --build-arg FLASH_ATTN_WHEEL=$(FLASH_ATTN_WHEEL) \ --build-arg GROUPED_GEMM_VERSION=$(GROUPED_GEMM_VERSION) \ - --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --build-arg TORCH_NIGHTLY_VERSION=$(TORCH_NIGHTLY_VERSION) \ --target nightly \ diff --git a/src/Dockerfile b/src/Dockerfile index c3ab82b22..0ca44b8dc 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -26,13 +26,10 @@ RUN pip install --upgrade --no-cache-dir pip wheel packaging "setuptools<70.0.0" # Build megablocks, grouped-gemm, stanford-stk ENV TORCH_CUDA_ARCH_LIST="8.0 9.0" -ENV GROUPED_GEMM_CUTLASS="1" +# ENV GROUPED_GEMM_CUTLASS="1" ARG GROUPED_GEMM_VERSION="grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" RUN pip wheel --no-build-isolation --no-cache-dir "${GROUPED_GEMM_VERSION}" -ARG MEGABLOCKS_VERSION="megablocks @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" -RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" - # Build flash-attn. ARG FLASH_ATTN_WHEEL=https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl RUN wget ${FLASH_ATTN_WHEEL} From 674850dffb952d49599f0394d2ffe3050acdde34 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:28:03 -0800 Subject: [PATCH 041/230] update build deps --- Makefile | 6 +++--- src/Dockerfile | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 83dc65259..44bdeae53 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,14 @@ CUDA_VERSION = "12.4" TORCH_CUDA_VERSION = $(shell echo $(CUDA_VERSION) | tr -d .) -TORCH_VERSION = "2.5.1" +TORCH_VERSION = "2.6.0" TORCH_VERSION_SHORT = $(shell echo $(TORCH_VERSION) | tr -d .) # NOTE: when upgrading the nightly version you also need to upgrade the torch version specification # in 'pyproject.toml' to include that nightly version. TORCH_NIGHTLY_VERSION = "2.6.0.dev20241209" TORCH_NIGHTLY_VERSION_SHORT = $(shell echo $(TORCH_NIGHTLY_VERSION) | tr -d .) 
-TORCHAO_VERSION = "0.6.1" +TORCHAO_VERSION = "0.8.0" GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" -FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl +FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl VERSION = $(shell python src/olmo_core/version.py) VERSION_SHORT = $(shell python src/olmo_core/version.py short) diff --git a/src/Dockerfile b/src/Dockerfile index 0ca44b8dc..f9c212b58 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -1,7 +1,7 @@ # NOTE: make sure CUDA_VERSION and TORCH_CUDA_VERSION always match, except for punctuation ARG CUDA_VERSION="12.4" ARG TORCH_CUDA_VERSION="124" -ARG TORCH_VERSION="2.5.1" +ARG TORCH_VERSION="2.6.0 ######################################################################### # Build image @@ -24,20 +24,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ # Install/upgrade Python build dependencies. RUN pip install --upgrade --no-cache-dir pip wheel packaging "setuptools<70.0.0" ninja -# Build megablocks, grouped-gemm, stanford-stk +# Build grouped-gemm ENV TORCH_CUDA_ARCH_LIST="8.0 9.0" # ENV GROUPED_GEMM_CUTLASS="1" ARG GROUPED_GEMM_VERSION="grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" RUN pip wheel --no-build-isolation --no-cache-dir "${GROUPED_GEMM_VERSION}" # Build flash-attn. -ARG FLASH_ATTN_WHEEL=https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl +ARG FLASH_ATTN_WHEEL=https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl RUN wget ${FLASH_ATTN_WHEEL} # Only keep the target wheels and dependencies with CUDA extensions. RUN echo "Built wheels:" \ && ls -lh . \ - && ls -1 | grep -Ev 'megablocks|grouped_gemm|stanford_stk|flash_attn' | xargs rm \ + && ls -1 | grep -Ev 'grouped_gemm|flash_attn' | xargs rm \ && echo "Final wheels:" \ && ls -lh . @@ -73,7 +73,7 @@ RUN pip install --upgrade --no-cache-dir pip wheel packaging # Install torchao. ARG TORCH_CUDA_VERSION -ARG TORCHAO_VERSION="0.6.1" +ARG TORCHAO_VERSION="0.8.0" RUN pip install --no-cache-dir \ --extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \ torchao==${TORCHAO_VERSION} From 36ab5b13475ed9d47e5946b8cd43832ebffcdbaf Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:32:36 -0800 Subject: [PATCH 042/230] fix --- src/Dockerfile | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Dockerfile b/src/Dockerfile index f9c212b58..7588095fa 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -35,11 +35,10 @@ ARG FLASH_ATTN_WHEEL=https://github.com/Dao-AILab/flash-attention/releases/downl RUN wget ${FLASH_ATTN_WHEEL} # Only keep the target wheels and dependencies with CUDA extensions. -RUN echo "Built wheels:" \ - && ls -lh . \ - && ls -1 | grep -Ev 'grouped_gemm|flash_attn' | xargs rm \ - && echo "Final wheels:" \ - && ls -lh . +RUN echo "Built wheels:" && ls -lh . + # && ls -1 | grep -Ev 'grouped_gemm|flash_attn' | xargs rm \ + # && echo "Final wheels:" \ + # && ls -lh . 
######################################################################### # Stable image From b94fa692598d37e2328d1f787112e543201987fd Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:41:07 -0800 Subject: [PATCH 043/230] update stable image --- src/olmo_core/launch/beaker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 0eaae09cc..047a9f3d8 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -57,7 +57,7 @@ class OLMoCoreBeakerImage(StrEnum): includes *versioned* images that are published with each release of the OLMo-core package. """ - stable = "olmo-core-tch251cu124" + stable = "olmo-core-tch260cu124" """ Built with the latest compatible stable version of PyTorch. """ From 0479f74990fb6800c3763b74843d4a0cb0c418c1 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:42:35 -0800 Subject: [PATCH 044/230] update nightly build --- Makefile | 2 +- src/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 44bdeae53..7785e8310 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ TORCH_VERSION = "2.6.0" TORCH_VERSION_SHORT = $(shell echo $(TORCH_VERSION) | tr -d .) # NOTE: when upgrading the nightly version you also need to upgrade the torch version specification # in 'pyproject.toml' to include that nightly version. -TORCH_NIGHTLY_VERSION = "2.6.0.dev20241209" +TORCH_NIGHTLY_VERSION = "2.7.0.dev20250202" TORCH_NIGHTLY_VERSION_SHORT = $(shell echo $(TORCH_NIGHTLY_VERSION) | tr -d .) TORCHAO_VERSION = "0.8.0" GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" diff --git a/src/Dockerfile b/src/Dockerfile index 7588095fa..2801e29f2 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -99,7 +99,7 @@ WORKDIR /app/olmo-core FROM stable as nightly ARG TORCH_CUDA_VERSION -ARG TORCH_NIGHTLY_VERSION="2.6.0.dev20241209" +ARG TORCH_NIGHTLY_VERSION="2.7.0.dev20250202" RUN pip install --no-cache-dir --pre \ --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} \ torch==${TORCH_NIGHTLY_VERSION}+cu${TORCH_CUDA_VERSION} From 59867e813e3a271bf794890deed57f591969b694 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:47:55 -0800 Subject: [PATCH 045/230] pin grouped gemm to commit --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 7785e8310..5cb065b97 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ TORCH_VERSION_SHORT = $(shell echo $(TORCH_VERSION) | tr -d .) TORCH_NIGHTLY_VERSION = "2.7.0.dev20250202" TORCH_NIGHTLY_VERSION_SHORT = $(shell echo $(TORCH_NIGHTLY_VERSION) | tr -d .) 
TORCHAO_VERSION = "0.8.0" -GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" +GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@ebeae0bb3ded459886309b2a30410deb16937af4" FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl VERSION = $(shell python src/olmo_core/version.py) From 3a7bc845260e9a0ade23d05839e57c3be6b38739 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:50:09 -0800 Subject: [PATCH 046/230] build with CUTLASS again --- Makefile | 2 +- src/Dockerfile | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 5cb065b97..7785e8310 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ TORCH_VERSION_SHORT = $(shell echo $(TORCH_VERSION) | tr -d .) TORCH_NIGHTLY_VERSION = "2.7.0.dev20250202" TORCH_NIGHTLY_VERSION_SHORT = $(shell echo $(TORCH_NIGHTLY_VERSION) | tr -d .) TORCHAO_VERSION = "0.8.0" -GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@ebeae0bb3ded459886309b2a30410deb16937af4" +GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl VERSION = $(shell python src/olmo_core/version.py) diff --git a/src/Dockerfile b/src/Dockerfile index 2801e29f2..8bea146f2 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -24,9 +24,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ # Install/upgrade Python build dependencies. RUN pip install --upgrade --no-cache-dir pip wheel packaging "setuptools<70.0.0" ninja -# Build grouped-gemm +# Build grouped-gemm. +# NOTE: right now we need to build with CUTLASS so we can pass batch sizes on GPU. 
+# See https://github.com/tgale96/grouped_gemm/pull/21 ENV TORCH_CUDA_ARCH_LIST="8.0 9.0" -# ENV GROUPED_GEMM_CUTLASS="1" +ENV GROUPED_GEMM_CUTLASS="1" ARG GROUPED_GEMM_VERSION="grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" RUN pip wheel --no-build-isolation --no-cache-dir "${GROUPED_GEMM_VERSION}" From c54d882d93c9fe8b069efe44ffde1c320e02fd66 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 10:57:51 -0800 Subject: [PATCH 047/230] update images used --- .github/workflows/main.yml | 6 +++--- src/olmo_core/internal/common.py | 2 +- src/olmo_core/launch/beaker.py | 2 +- src/scripts/train/all_reduce_bench.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8014d5cfe..683b4aafd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -111,7 +111,7 @@ jobs: matrix: task: - name: Test (GPU) - image: olmo-core-tch251cu124 + image: olmo-core-tch260cu124 gpus: 2 run: | pytest -v --color=yes --durations=3 -m gpu \ @@ -120,14 +120,14 @@ jobs: src/test/ - name: Test checkpoint (GPU) - image: olmo-core-tch251cu124 + image: olmo-core-tch260cu124 gpus: 2 run: | pytest -v --color=yes --durations=3 -m gpu \ src/test/distributed/checkpoint* - name: Test MoE (GPU) - image: olmo-core-tch251cu124 + image: olmo-core-tch260cu124 gpus: 1 run: | pytest -v --color=yes --durations=3 -m gpu \ diff --git a/src/olmo_core/internal/common.py b/src/olmo_core/internal/common.py index dfef00087..9247c5ae2 100644 --- a/src/olmo_core/internal/common.py +++ b/src/olmo_core/internal/common.py @@ -80,7 +80,7 @@ def build_launch_config( workspace=workspace, clusters=[cluster], weka_buckets=weka_buckets, - beaker_image=OLMoCoreBeakerImage.nightly, # some features require nightly at the moment + beaker_image=OLMoCoreBeakerImage.stable, num_nodes=1, num_gpus=8, shared_filesystem=not is_url(root_dir), diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 047a9f3d8..f12b52d8d 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -62,7 +62,7 @@ class OLMoCoreBeakerImage(StrEnum): Built with the latest compatible stable version of PyTorch. """ - nightly = "olmo-core-tch260dev20241209cu124" + nightly = "olmo-core-tch270dev20250202cu124" """ Built with the latest compatible nightly version of PyTorch. 
""" diff --git a/src/scripts/train/all_reduce_bench.py b/src/scripts/train/all_reduce_bench.py index b00ce60ae..78eeebe4e 100644 --- a/src/scripts/train/all_reduce_bench.py +++ b/src/scripts/train/all_reduce_bench.py @@ -102,7 +102,7 @@ def build_config(script: str, run_name: str, cluster: str, overrides: List[str]) task_name="benchmark", workspace="ai2/OLMo-core", clusters=[cluster], - beaker_image=OLMoCoreBeakerImage.nightly, # some features require nightly at the moment + beaker_image=OLMoCoreBeakerImage.stable, num_nodes=1, num_gpus=8, allow_dirty=False, From 0c7baed3beec0289255e0f808e9918e489646aaa Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 16:01:25 -0800 Subject: [PATCH 048/230] fix expert parallelism, implement sequence parallelism --- .../distributed/parallel/__init__.py | 106 ++++++++++++------ .../distributed/parallel/tensor_parallel.py | 2 +- src/olmo_core/nn/moe/mlp.py | 12 +- src/olmo_core/nn/moe/moe.py | 32 +++++- src/olmo_core/nn/moe/parallel_mlp.py | 4 +- src/olmo_core/nn/transformer/block.py | 59 ++++------ src/olmo_core/nn/transformer/model.py | 15 ++- .../train/train_module/transformer.py | 5 +- src/test/nn/moe/mlp_test.py | 7 +- 9 files changed, 146 insertions(+), 96 deletions(-) diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index 2ce79b11d..2aae0eea2 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -19,12 +19,13 @@ __all__ = [ "build_device_mesh", + "build_expert_parallel_mesh", "MeshDimName", "get_dp_mesh", "get_tp_mesh", - "get_ep_mesh", "get_pp_mesh", "get_dp_process_group", + "get_num_ep_shards", "DataParallelType", "DataParallelConfig", "DPMeshDimName", @@ -58,9 +59,14 @@ class MeshDimName(StrEnum): The DP dimension over which the model is sharded. """ - ep = "ep" + ep_replicate = "ep_replicate" """ - Expert parallel (EP). + The EP dimension over which the experts are replicated. + """ + + ep_shard = "ep_shard" + """ + The EP dimension over which the experts are sharded. """ tp = "tp" @@ -77,7 +83,6 @@ class MeshDimName(StrEnum): def build_device_mesh( *, dp: Optional[DataParallelConfig] = None, - ep: Optional[ExpertParallelConfig] = None, tp: Optional[TensorParallelConfig] = None, pp: Optional[PipelineParallelConfig] = None, device_type: Optional[str] = None, @@ -88,8 +93,12 @@ def build_device_mesh( .. important:: A data parallel config is required if any other parallel config is set. + + .. seealso:: + Expert parallel device meshes need to be created separately with + :func:`build_expert_parallel_mesh`. """ - if ep is None and pp is None and tp is None and dp is None: + if pp is None and tp is None and dp is None: return None if dp is None: raise OLMoConfigurationError( @@ -112,12 +121,6 @@ def build_device_mesh( f"{tp.__class__.__name__}.degree must be at least 1 and divide into the world size" ) dp_world_size //= tp.degree - if ep is not None: - if ep.degree < 1 or dp_world_size % ep.degree != 0: - raise OLMoConfigurationError( - f"{ep.__class__.__name__}.degree must be at least 1 and divide into the world size" - ) - dp_world_size //= ep.degree # Build up mesh dimensions. 
names: List[str] = [] @@ -148,9 +151,6 @@ def build_device_mesh( if tp is not None: names.append(MeshDimName.tp) dims.append(tp.degree) - if ep is not None: - names.append(MeshDimName.ep) - dims.append(ep.degree) log.info(f"Building {len(dims)}-D device mesh with dimensions:") for i, (name, dim) in enumerate(zip(names, dims)): @@ -162,6 +162,40 @@ def build_device_mesh( return mesh +def build_expert_parallel_mesh( + ep_config: ExpertParallelConfig, device_type: Optional[str] = None +) -> Optional[DeviceMesh]: + """ + Build a device mesh for expert parallelism. + """ + device_type = device_type or get_default_device().type + world_size = get_world_size() + + if ep_config.degree == world_size: + return None + + # Build up mesh dimensions. + names: List[str] = [] + dims: List[int] = [] + + if world_size % ep_config.degree != 0: + raise OLMoConfigurationError( + f"Expert parallelism requires world size ({world_size}) to " + f"be divisible by 'degree' ({ep_config.degree})" + ) + names.append(MeshDimName.ep_replicate) + dims.append(world_size // ep_config.degree) + + names.append(MeshDimName.ep_shard) + dims.append(ep_config.degree) + + log.info(f"Building {len(dims)}-D device mesh with dimensions:") + for i, (name, dim) in enumerate(zip(names, dims)): + log.info(f" > dimension {i}, size={dim}, name={name}") + + return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) + + def get_dp_mesh( device_mesh: Optional[DeviceMesh] = None, *, @@ -174,8 +208,8 @@ def get_dp_mesh( created from :func:`build_device_mesh()`. :param dim_name: The name of the base data parallel mesh dimension. - :param dim_name: The name of the replica-specific data parallel mesh dimension. - :param dim_name: The name of the shard-specific data parallel mesh dimension. + :param replicate_dim_name: The name of the replica-specific data parallel mesh dimension. + :param shard_dim_name: The name of the shard-specific data parallel mesh dimension. """ if device_mesh is None: return None @@ -243,11 +277,11 @@ def get_tp_mesh( return None -def get_ep_mesh( - device_mesh: Optional[DeviceMesh] = None, *, dim_name: str = MeshDimName.ep +def get_pp_mesh( + device_mesh: Optional[DeviceMesh] = None, *, dim_name: str = MeshDimName.pp ) -> Optional[DeviceMesh]: """ - Get the expert parallel sub-mesh associated with a ``DeviceMesh`` that was potentially + Get the tensor parallel sub-mesh associated with a ``DeviceMesh`` that was potentially created from :func:`build_device_mesh()`. :param dim_name: The name of the target mesh dimension. @@ -256,7 +290,7 @@ def get_ep_mesh( return None if device_mesh.mesh_dim_names is None: - raise RuntimeError("could not determine expert parallel sub-mesh without dimension names") + raise RuntimeError("could not determine pipeline parallel sub-mesh without dimension names") if dim_name in device_mesh.mesh_dim_names: return device_mesh[dim_name] @@ -264,22 +298,22 @@ def get_ep_mesh( return None -def get_pp_mesh( - device_mesh: Optional[DeviceMesh] = None, *, dim_name: str = MeshDimName.pp -) -> Optional[DeviceMesh]: +def get_num_ep_shards( + ep_mesh: Optional[DeviceMesh] = None, *, shard_dim_name: Optional[str] = None +) -> int: """ - Get the tensor parallel sub-mesh associated with a ``DeviceMesh`` that was potentially - created from :func:`build_device_mesh()`. - - :param dim_name: The name of the target mesh dimension. + Get the number of expert parallel shards. 
""" - if device_mesh is None: - return None - - if device_mesh.mesh_dim_names is None: - raise RuntimeError("could not determine pipeline parallel sub-mesh without dimension names") - - if dim_name in device_mesh.mesh_dim_names: - return device_mesh[dim_name] + if ep_mesh is None: + return get_world_size() + + if ep_mesh.mesh_dim_names is None: + raise RuntimeError("could not determine expert parallel shard sub-mesh") + elif shard_dim_name is not None: + return ep_mesh[shard_dim_name].size() + elif MeshDimName.ep_shard in ep_mesh.mesh_dim_names: + return ep_mesh[MeshDimName.ep_shard].size() + elif MeshDimName.tp in ep_mesh.mesh_dim_names: + return ep_mesh[MeshDimName.tp].size() else: - return None + raise RuntimeError("could not determine expert parallel shard sub-mesh") diff --git a/src/olmo_core/distributed/parallel/tensor_parallel.py b/src/olmo_core/distributed/parallel/tensor_parallel.py index 14e80f122..0dbafd818 100644 --- a/src/olmo_core/distributed/parallel/tensor_parallel.py +++ b/src/olmo_core/distributed/parallel/tensor_parallel.py @@ -49,7 +49,7 @@ def __init__( output_layouts: Optional[Placement] = None, ): super().__init__(sequence_dim=sequence_dim, use_local_output=use_local_output) - self.output_layouts = (output_layouts or Shard(1),) + self.output_layouts = (output_layouts or Shard(sequence_dim),) @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 59ee5e205..0efbeb16a 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -9,6 +9,7 @@ from torch.distributed.tensor import Shard, distribute_tensor from ...config import Config, DType, StrEnum +from ...distributed.parallel import get_num_ep_shards from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError @@ -199,17 +200,18 @@ def forward(self, x: torch.Tensor, batch_size_per_expert: torch.Tensor) -> torch x1 = F.silu(x1) * x2 return self.gmm(x1, w2, batch_size_per_expert) - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): """ Apply expert parallelism. """ - if self.num_experts % ep_mesh.size() != 0: + num_shards = get_num_ep_shards(ep_mesh) + if self.num_experts % num_shards != 0: raise OLMoConfigurationError( - f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel degree ({ep_mesh.size()})." + f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel shard degree ({num_shards})." 
) - self.experts_per_rank = self.num_experts // ep_mesh.size() - self.gradient_scale = 1.0 / ep_mesh.size() + self.experts_per_rank = self.num_experts // num_shards + self.gradient_scale = 1.0 / num_shards self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)]))) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 47c74dcb6..4be1b8af3 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -5,9 +5,13 @@ import torch import torch.nn as nn from torch.distributed import DeviceMesh +from torch.distributed.tensor import Placement, Replicate, Shard +from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module + +from olmo_core.config import Config, StrEnum +from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel +from olmo_core.exceptions import OLMoConfigurationError -from ...config import Config, StrEnum -from ...exceptions import OLMoConfigurationError from .loss import MoELoadBalancingLoss, MoELoss, MoERouterZLoss from .mlp import MoEMLP, MoEMLPConfig from .parallel_mlp import ParallelDroplessMLP, ParallelMLP @@ -164,12 +168,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): """ Apply expert parallelism. """ self.experts.apply_ep(ep_mesh) + def apply_tp( + self, + tp_mesh: Optional[DeviceMesh] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + ): + parallelize_module( + self.router, + device_mesh=tp_mesh, + parallelize_plan=SequenceParallel(use_local_output=True), + ) + self.experts.apply_ep(tp_mesh) + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleOutput( + output_layouts=(Shard(1),), + desired_output_layouts=(output_layouts or Replicate(),), + use_local_output=use_local_output, + ), + ) + class DroplessMoE(MoEBase): """ diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index dd5cd6533..a74fa2412 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -49,14 +49,14 @@ def ep_world_size(self) -> int: else: return 1 - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): """ Apply expert parallelism. 
""" self.mlp.apply_ep(ep_mesh) self._expert_parallel_enabled = True self._ep_mesh = ep_mesh - self._ep_pg = ep_mesh.get_group() + self._ep_pg = None if ep_mesh is None else ep_mesh.get_group() def forward( self, diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 000ed103e..643740910 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -251,38 +251,6 @@ def forward( ) return h + self.dropout(self.feed_forward_norm(self.feed_forward(h))) - @property - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - return Replicate() - - def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): - _, _, prepare_module_input = get_tp_wrappers(float8_enabled=float8_enabled) - - plan = { - "attention": prepare_module_input( - input_layouts=(Replicate(),), - desired_input_layouts=(Replicate(),), - ), - "attention_norm": SequenceParallel(output_layouts=Replicate(), use_local_output=True), - "feed_forward": prepare_module_input( - input_layouts=(Replicate(),), - desired_input_layouts=(Replicate(),), - ), - "feed_forward_norm": SequenceParallel( - output_layouts=Replicate(), use_local_output=True - ), - } - if isinstance(self.dropout, nn.Dropout): - plan["dropout"] = SequenceParallel() - parallelize_module( - module=self, - device_mesh=tp_mesh, - parallelize_plan=plan, - ) - - self.attention.apply_tp(tp_mesh, output_layouts=Shard(1), float8_enabled=float8_enabled) - self.feed_forward.apply_tp(tp_mesh, output_layouts=Shard(1), float8_enabled=float8_enabled) - @beta_feature class NormalizedTransformerBlock(TransformerBlockBase): @@ -431,18 +399,33 @@ def forward( ) return h + self.dropout(self.feed_forward_moe(self.feed_forward_norm(h))) - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): self.feed_forward_moe.apply_ep(ep_mesh) + def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: + return Shard(1) + def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): - del tp_mesh, float8_enabled + _, _, prepare_module_input = get_tp_wrappers(float8_enabled=float8_enabled) - raise NotImplementedError( - f"TP is not implemented yet for the '{self.__class__.__name__}' variant" + plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward_norm": SequenceParallel(), + } + if isinstance(self.dropout, nn.Dropout): + plan["dropout"] = SequenceParallel() + parallelize_module( + module=self, + device_mesh=tp_mesh, + parallelize_plan=plan, ) - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - raise NotImplementedError + self.attention.apply_tp(tp_mesh, output_layouts=Shard(1), float8_enabled=float8_enabled) + self.feed_forward_moe.apply_tp(tp_mesh, output_layouts=Shard(1), use_local_output=False) class MoEReorderedNormTransformerBlock(MoETransformerBlock): diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 8f61edc1d..a2e19476e 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -277,16 +277,13 @@ def apply_tp( parallelize_module, ) - emb_output_layout = cast(TransformerBlockBase, self.blocks["0"]).tp_input_layouts parallelize_module( module=self, device_mesh=tp_mesh, parallelize_plan={ "embeddings": RowwiseParallel( input_layouts=Replicate(), - output_layouts=emb_output_layout[0] - if 
isinstance(emb_output_layout, tuple) - else emb_output_layout, + use_local_output=False, ), "lm_head": PrepareModuleInput( # block output layouts are same as block input layouts @@ -303,7 +300,13 @@ def apply_tp( # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 for block in self.blocks.values(): - cast(TransformerBlockBase, block).apply_tp(tp_mesh, float8_enabled=float8_enabled) + block = cast(TransformerBlockBase, block) + block.apply_tp(tp_mesh, float8_enabled=float8_enabled) + parallelize_module( + block, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput(desired_input_layouts=block.tp_input_layouts), + ) def apply_activation_checkpointing( self, @@ -654,6 +657,6 @@ def forward( return self.lm_head(h) if self.lm_head is not None else h - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): for block in self.blocks.values(): cast(MoETransformerBlock, block).apply_ep(ep_mesh) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 2f3f9b007..f9ccf6fe9 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -22,9 +22,9 @@ ExpertParallelConfig, TensorParallelConfig, build_device_mesh, + build_expert_parallel_mesh, get_dp_mesh, get_dp_process_group, - get_ep_mesh, get_tp_mesh, ) from olmo_core.distributed.utils import get_local_tensor, get_world_size @@ -318,8 +318,7 @@ def __init__( if ep_config is not None: if not self.model.is_moe: raise OLMoConfigurationError("Expert parallelism is only valid for MoE models") - ep_mesh = get_ep_mesh(self.world_mesh) - assert ep_mesh is not None + ep_mesh = build_expert_parallel_mesh(ep_config) cast(MoETransformer, self.model).apply_ep(ep_mesh) log.info("Applied expert parallelism to the model") diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index c7d23ece2..767617665 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -1,7 +1,10 @@ import torch import torch.distributed as dist -from torch.distributed.tensor import init_device_mesh +from olmo_core.distributed.parallel import ( + ExpertParallelConfig, + build_expert_parallel_mesh, +) from olmo_core.distributed.utils import get_local_tensor from olmo_core.nn.moe.mlp import MoEMLP from olmo_core.utils import get_default_device @@ -22,7 +25,7 @@ def test_mlp(): def run_mlp_with_expert_parallelism(): - ep_mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) + ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=dist.get_world_size())) mlp = MoEMLP( d_model=128, From dfb118aa1cef535a46ef012263abc201ff5882ac Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 16:04:15 -0800 Subject: [PATCH 049/230] fix test --- src/test/nn/moe/moe_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index b99cd36e7..824589a6d 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -3,9 +3,12 @@ import pytest import torch import torch.distributed as dist -from torch.distributed.tensor import init_device_mesh from olmo_core.config import DType +from olmo_core.distributed.parallel import ( + ExpertParallelConfig, + build_expert_parallel_mesh, +) from olmo_core.nn.moe import MoEConfig, MoEMLPConfig, MoERouterConfig, MoEType from olmo_core.utils import 
get_default_device @@ -59,7 +62,7 @@ def test_moe(moe_type, dtype): def run_moe_with_expert_parallelism(moe_type, dtype): - ep_mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) + ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=min(dist.get_world_size(), 2))) d_model = 128 config = MoEConfig( From baef2f21186b703ee1d2aacdcbc469bef7bf67d7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 16:11:55 -0800 Subject: [PATCH 050/230] fix --- src/olmo_core/distributed/parallel/__init__.py | 11 +++-------- src/olmo_core/nn/moe/mlp.py | 2 +- src/olmo_core/nn/moe/moe.py | 4 ++-- src/olmo_core/nn/moe/parallel_mlp.py | 2 +- src/olmo_core/nn/transformer/block.py | 2 +- src/olmo_core/nn/transformer/model.py | 2 +- 6 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index 2aae0eea2..d55145d56 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -164,7 +164,7 @@ def build_device_mesh( def build_expert_parallel_mesh( ep_config: ExpertParallelConfig, device_type: Optional[str] = None -) -> Optional[DeviceMesh]: +) -> DeviceMesh: """ Build a device mesh for expert parallelism. """ @@ -172,7 +172,7 @@ def build_expert_parallel_mesh( world_size = get_world_size() if ep_config.degree == world_size: - return None + return init_device_mesh(device_type, (world_size,), mesh_dim_names=(MeshDimName.ep_shard,)) # Build up mesh dimensions. names: List[str] = [] @@ -298,15 +298,10 @@ def get_pp_mesh( return None -def get_num_ep_shards( - ep_mesh: Optional[DeviceMesh] = None, *, shard_dim_name: Optional[str] = None -) -> int: +def get_num_ep_shards(ep_mesh: DeviceMesh, *, shard_dim_name: Optional[str] = None) -> int: """ Get the number of expert parallel shards. """ - if ep_mesh is None: - return get_world_size() - if ep_mesh.mesh_dim_names is None: raise RuntimeError("could not determine expert parallel shard sub-mesh") elif shard_dim_name is not None: diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 0efbeb16a..40797437f 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -200,7 +200,7 @@ def forward(self, x: torch.Tensor, batch_size_per_expert: torch.Tensor) -> torch x1 = F.silu(x1) * x2 return self.gmm(x1, w2, batch_size_per_expert) - def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): + def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. """ diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 4be1b8af3..2ea3a2672 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -168,7 +168,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out - def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): + def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. 
""" @@ -176,7 +176,7 @@ def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): def apply_tp( self, - tp_mesh: Optional[DeviceMesh] = None, + tp_mesh: DeviceMesh, output_layouts: Optional[Placement] = None, use_local_output: bool = True, ): diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index a74fa2412..69b7a1a74 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -49,7 +49,7 @@ def ep_world_size(self) -> int: else: return 1 - def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): + def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. """ diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 643740910..e06cec995 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -399,7 +399,7 @@ def forward( ) return h + self.dropout(self.feed_forward_moe(self.feed_forward_norm(h))) - def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): + def apply_ep(self, ep_mesh: DeviceMesh): self.feed_forward_moe.apply_ep(ep_mesh) def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index a2e19476e..935238d1e 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -657,6 +657,6 @@ def forward( return self.lm_head(h) if self.lm_head is not None else h - def apply_ep(self, ep_mesh: Optional[DeviceMesh] = None): + def apply_ep(self, ep_mesh: DeviceMesh): for block in self.blocks.values(): cast(MoETransformerBlock, block).apply_ep(ep_mesh) From d352a2bd90d3a42ffe35ab0717f08c2b1b558949 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 2 Feb 2025 16:41:50 -0800 Subject: [PATCH 051/230] fix --- .../distributed/parallel/__init__.py | 21 +++---------------- src/olmo_core/nn/moe/mlp.py | 5 +++-- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index d55145d56..9e242d65d 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -25,7 +25,6 @@ "get_tp_mesh", "get_pp_mesh", "get_dp_process_group", - "get_num_ep_shards", "DataParallelType", "DataParallelConfig", "DPMeshDimName", @@ -193,7 +192,9 @@ def build_expert_parallel_mesh( for i, (name, dim) in enumerate(zip(names, dims)): log.info(f" > dimension {i}, size={dim}, name={name}") - return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) + return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names))[ + MeshDimName.ep_shard + ] def get_dp_mesh( @@ -296,19 +297,3 @@ def get_pp_mesh( return device_mesh[dim_name] else: return None - - -def get_num_ep_shards(ep_mesh: DeviceMesh, *, shard_dim_name: Optional[str] = None) -> int: - """ - Get the number of expert parallel shards. 
- """ - if ep_mesh.mesh_dim_names is None: - raise RuntimeError("could not determine expert parallel shard sub-mesh") - elif shard_dim_name is not None: - return ep_mesh[shard_dim_name].size() - elif MeshDimName.ep_shard in ep_mesh.mesh_dim_names: - return ep_mesh[MeshDimName.ep_shard].size() - elif MeshDimName.tp in ep_mesh.mesh_dim_names: - return ep_mesh[MeshDimName.tp].size() - else: - raise RuntimeError("could not determine expert parallel shard sub-mesh") diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 40797437f..2d94555b2 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -9,7 +9,6 @@ from torch.distributed.tensor import Shard, distribute_tensor from ...config import Config, DType, StrEnum -from ...distributed.parallel import get_num_ep_shards from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError @@ -204,7 +203,9 @@ def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. """ - num_shards = get_num_ep_shards(ep_mesh) + if ep_mesh.ndim > 1: + raise RuntimeError("local expert parallel sub-mesh must be 1-dimensional") + num_shards = ep_mesh.size() if self.num_experts % num_shards != 0: raise OLMoConfigurationError( f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel shard degree ({num_shards})." From 83d28af4c653b46b71bacc3b2cb5794fe75e558f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 11:18:12 -0800 Subject: [PATCH 052/230] Start on regular MoE --- src/olmo_core/nn/moe/kernels.py | 236 +++++++++++++++++++++++++++ src/olmo_core/nn/moe/moe.py | 59 ++++++- src/olmo_core/nn/moe/ops.py | 94 +++++++++++ src/olmo_core/nn/moe/parallel_mlp.py | 184 +++++++++++++++------ 4 files changed, 512 insertions(+), 61 deletions(-) diff --git a/src/olmo_core/nn/moe/kernels.py b/src/olmo_core/nn/moe/kernels.py index 7c5ef9ab9..c11edf826 100644 --- a/src/olmo_core/nn/moe/kernels.py +++ b/src/olmo_core/nn/moe/kernels.py @@ -298,3 +298,239 @@ def scatter_wgrad( top_k: int, ) -> torch.Tensor: return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), + ], + key=["NUM_COLUMNS"], +) +@triton.jit +def _binned_copy( + a, # (tokens, hidden_size), real. + b, # (num_experts, expert_capacity, num_columns), real. + expert_capacity, + indices, # (tokens * top_k), integer. + weights, # (tokens * top_k), real. + bins, # (num_experts), integer. + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
+ # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) # type: ignore + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather( + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + expert_capacity: int, + top_k: int, +) -> torch.Tensor: + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( # type: ignore + x, + out, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( # type: ignore + out, + x, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), + ], + key=["NUM_COLUMNS"], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. 
+ start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad( + x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int +) -> torch.Tensor: + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( # type: ignore + x, + grad, + out, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 2ea3a2672..effc26195 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -14,11 +14,11 @@ from .loss import MoELoadBalancingLoss, MoELoss, MoERouterZLoss from .mlp import MoEMLP, MoEMLPConfig -from .parallel_mlp import ParallelDroplessMLP, ParallelMLP +from .parallel_mlp import ParallelDroplessMLP, ParallelMLP, ParallelMLPBase from .router import MoERouterConfig from .shared_mlp import SharedMLPConfig -__all__ = ["MoEBase", "DroplessMoE", "MoEConfig", "MoEType"] +__all__ = ["MoEBase", "MoE", "DroplessMoE", "MoEConfig", "MoEType"] class MoEType(StrEnum): @@ -26,6 +26,11 @@ class MoEType(StrEnum): An enumeration of the different MoE implementations. """ + default = "default" + """ + ➡️ :class:`MoE` + """ + dropless = "dropless" """ ➡️ :class:`DroplessMoE` @@ -34,12 +39,13 @@ class MoEType(StrEnum): @dataclass class MoEConfig(Config): - name: MoEType = MoEType.dropless + name: MoEType = MoEType.dropless # TODO: change to default """ The name of the implementation. 
""" num_experts: int = 1 hidden_size: int = 256 + capacity_factor: Optional[float] = None router: MoERouterConfig = field(default_factory=MoERouterConfig) mlp: MoEMLPConfig = field(default_factory=MoEMLPConfig) shared_mlp: Optional[SharedMLPConfig] = None @@ -101,11 +107,12 @@ def __init__( init_device: str = "cpu", lb_loss_weight: Optional[float] = None, z_loss_weight: Optional[float] = None, + **kwargs, ): super().__init__() self.router = router.build(d_model, num_experts, init_device=init_device) self.experts = self._init_parallel_mlp( - mlp.build(d_model, num_experts, hidden_size, init_device=init_device) + mlp.build(d_model, num_experts, hidden_size, init_device=init_device), **kwargs ) self.shared_experts = ( None @@ -142,9 +149,8 @@ def reset_losses(self): for loss_fn in self.losses: loss_fn.reset() - @classmethod @abstractmethod - def _init_parallel_mlp(cls, mlp: MoEMLP) -> ParallelMLP: + def _init_parallel_mlp(self, mlp: MoEMLP, **kwargs) -> ParallelMLPBase: raise NotImplementedError def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -197,11 +203,48 @@ def apply_tp( ) +class MoE(MoEBase): + """ + A basic MoE implementation. + """ + + def __init__( + self, + *, + d_model: int, + num_experts: int, + hidden_size: int, + router: MoERouterConfig, + mlp: MoEMLPConfig, + num_layers: int, + shared_mlp: Optional[SharedMLPConfig] = None, + capacity_factor: float = 1.2, + init_device: str = "cpu", + lb_loss_weight: Optional[float] = None, + z_loss_weight: Optional[float] = None, + ): + super().__init__( + d_model=d_model, + num_experts=num_experts, + hidden_size=hidden_size, + router=router, + mlp=mlp, + num_layers=num_layers, + shared_mlp=shared_mlp, + init_device=init_device, + lb_loss_weight=lb_loss_weight, + z_loss_weight=z_loss_weight, + capacity_factor=capacity_factor, + ) + + def _init_parallel_mlp(self, mlp: MoEMLP, *, capacity_factor: float) -> ParallelMLP: # type: ignore[override] + return ParallelMLP(mlp=mlp, capacity_factor=capacity_factor) + + class DroplessMoE(MoEBase): """ A dropless MoE implementation. """ - @classmethod - def _init_parallel_mlp(cls, mlp: MoEMLP) -> ParallelMLP: + def _init_parallel_mlp(self, mlp: MoEMLP) -> ParallelDroplessMLP: # type: ignore[override] return ParallelDroplessMLP(mlp=mlp) diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py index 93275d7f3..8d4b0f597 100644 --- a/src/olmo_core/nn/moe/ops.py +++ b/src/olmo_core/nn/moe/ops.py @@ -151,6 +151,100 @@ def scatter( return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) # type: ignore +class BinnedGatherOp(torch.autograd.Function): + @staticmethod + @autocast_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + from . import kernels + + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @autocast_bwd + def backward(ctx: Any, grad: torch.Tensor): + from . 
import kernels + + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +def binned_gather( + x: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, bin_size: int, top_k: int +) -> torch.Tensor: + return BinnedGatherOp.apply(x, indices, bins, bin_size, top_k) # type: ignore + + +class BinnedScatterOp(torch.autograd.Function): + @staticmethod + @autocast_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + from . import kernels + + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. + ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @autocast_bwd + def backward(ctx: Any, grad: torch.Tensor): + from . import kernels + + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: + return BinnedScatterOp.apply(x, indices, weights, bins, top_k) # type: ignore + + def repeat(x: torch.Tensor, tiling: Union[torch.Size, Tuple[int, ...]]) -> torch.Tensor: if all((t == 1 for t in tiling)): return x diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 69b7a1a74..71e44dabc 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -1,5 +1,6 @@ # Adapted from 'https://github.com/databricks/megablocks/blob/main/megablocks/layers/moe.py' and 'dmoe.py' +from abc import abstractmethod from typing import Optional, Tuple import torch @@ -11,10 +12,10 @@ from . import ops from .mlp import MoEMLP -__all__ = ["ParallelMLP", "ParallelDroplessMLP"] +__all__ = ["ParallelMLPBase", "ParallelMLP", "ParallelDroplessMLP"] -class ParallelMLP(nn.Module): +class ParallelMLPBase(nn.Module): """ Wraps an MoE MLP layer to coordinate the routing and expert parallelism. """ @@ -58,6 +59,33 @@ def apply_ep(self, ep_mesh: DeviceMesh): self._ep_mesh = ep_mesh self._ep_pg = None if ep_mesh is None else ep_mesh.get_group() + def indices_and_bins( + self, expert_indices: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + :param expert_indices: A 1D tensor. + """ + # Histogram the expert ids to identify the number of + # items/tokens routed to each expert. + # shape: (num_experts,), LongTensor + batch_size_per_expert = torch.histc( + expert_indices, bins=self.num_experts, min=0, max=self.num_experts - 1 + ) + + expert_indices = expert_indices.int() + + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # shape: (N,), (N,) + bin_ids, indices = torch.sort(expert_indices) + + # Calculate the bin bounds for the sorted items/tokens. 
+ # shape: (num_experts,) + bins = torch.empty_like(batch_size_per_expert, dtype=torch.int32) + torch.cumsum(batch_size_per_expert, 0, out=bins) + + return indices.int(), bin_ids, bins, batch_size_per_expert + def forward( self, x: torch.Tensor, @@ -72,21 +100,6 @@ def forward( :returns: The output with the same shape as ``x`` and a tensor with shape ``(num_experts,)`` containing the number of items/tokens routed to each expert. """ - del x, expert_weights, expert_indices - raise NotImplementedError - - -class ParallelDroplessMLP(ParallelMLP): - """ - A dropless implementation of a :class:`ParallelMLP`. - """ - - def forward( - self, - x: torch.Tensor, - expert_weights: torch.Tensor, - expert_indices: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: in_shape = x.size() # Compute the experts. @@ -97,6 +110,7 @@ def forward( return x.view(in_shape), batch_size_per_expert + @abstractmethod def forward_once( self, x: torch.Tensor, @@ -109,6 +123,103 @@ def forward_once( typically equals ``batch_size x seq_len``. :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. """ + raise NotImplementedError + + @abstractmethod + def parallel_forward_once( + self, + x: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param x: The input of shape ``(*, d_model)``. + :param expert_weights: Expert weights of shape ``(N, top_k)``, where ``N`` + typically equals ``batch_size x seq_len``. + :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. + """ + raise NotImplementedError + + +class ParallelMLP(ParallelMLPBase): + def __init__(self, *, mlp: MoEMLP, capacity_factor: float): + super().__init__(mlp=mlp) + self.capacity_factor = capacity_factor + + def expert_capacity(self, top_k: int, num_items: int) -> int: + items_per_expert = top_k * num_items * self.ep_world_size / self.num_experts + return int(self.capacity_factor * items_per_expert) + + def forward_once( + self, + x: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_items, top_k = expert_weights.shape + + # shape: (N * top_k,) + expert_weights = expert_weights.flatten() + # shape: (N * top_k,) + expert_indices = expert_indices.flatten() + + with torch.no_grad(): + indices, _, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) + expert_capacity = self.expert_capacity(top_k, num_items) + + x = self.permute_and_compute( + x, + indices, + expert_weights, + bins, + expert_capacity, + top_k, + ) + return x, batch_size_per_expert + + def parallel_forward_once( + self, + x: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + del x, expert_weights, expert_indices + raise NotImplementedError + + def permute_and_compute( + self, + x: torch.Tensor, + indices: torch.Tensor, + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ) -> torch.Tensor: + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Perform the expert computation. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + +class ParallelDroplessMLP(ParallelMLPBase): + """ + A dropless implementation of a :class:`ParallelMLP`. + + .. 
warning:: + When expert parallelism is enabled the forward pass involves a host-device sync. + """ + + def forward_once( + self, + x: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: top_k = expert_weights.shape[-1] # shape: (N * top_k,) @@ -117,11 +228,11 @@ def forward_once( expert_indices = expert_indices.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) + indices, bin_ids, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) out = self.permute_and_compute( x, - tokens_per_expert, + batch_size_per_expert, indices, bin_ids, expert_weights, @@ -129,7 +240,7 @@ def forward_once( top_k, ) - return out, tokens_per_expert + return out, batch_size_per_expert def parallel_forward_once( self, @@ -137,12 +248,6 @@ def parallel_forward_once( expert_weights: torch.Tensor, expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - :param x: The input of shape ``(*, d_model)``. - :param expert_weights: Expert weights of shape ``(N, top_k)``, where ``N`` - typically equals ``batch_size x seq_len``. - :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. - """ # NOTE: This function implements the same computation as forward_once # but with expert model parallelism. # @@ -303,33 +408,6 @@ def parallel_forward_once( return x, tokens_per_expert.flatten() - def indices_and_bins( - self, expert_indices: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - :param expert_indices: A 1D tensor. - """ - # Histogram the expert ids to identify the number of - # items/tokens routed to each expert. - # shape: (num_experts,), LongTensor - batch_size_per_expert = torch.histc( - expert_indices, bins=self.num_experts, min=0, max=self.num_experts - 1 - ) - - expert_indices = expert_indices.int() - - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # shape: (N,), (N,) - bin_ids, indices = torch.sort(expert_indices) - - # Calculate the bin bounds for the sorted items/tokens. - # shape: (num_experts,) - bins = torch.empty_like(batch_size_per_expert, dtype=torch.int32) - torch.cumsum(batch_size_per_expert, 0, out=bins) - - return indices.int(), bin_ids, bins, batch_size_per_expert - def permute_and_compute( self, x: torch.Tensor, From 62dfdc4b84c9b5b910e14233c624e39b5b0df03a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 15:16:10 -0800 Subject: [PATCH 053/230] finish? --- src/olmo_core/nn/moe/mlp.py | 4 +- src/olmo_core/nn/moe/moe.py | 4 +- src/olmo_core/nn/moe/ops.py | 15 ++- src/olmo_core/nn/moe/parallel_mlp.py | 188 ++++++++++++++++++++++----- src/test/nn/moe/moe_test.py | 6 +- 5 files changed, 172 insertions(+), 45 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 2d94555b2..d667db1c3 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -181,9 +181,9 @@ def forward(self, x: torch.Tensor, batch_size_per_expert: torch.Tensor) -> torch """ Compute the expert outputs. - :param x: The input of shape ``(N, d_model)``. + :param x: The input of shape ``(*, d_model)``. :param batch_size_per_expert: Specifies how many items/tokens go to each expert. Should be a - 1-D ``LongTensor`` which sums to ``N``. + 1-D ``LongTensor``. """ # Scale gradients and get local tensors (in case of expert parallelism). 
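        # Note: `scale_grad` is a no-op unless `apply_ep` has set `self.gradient_scale`
        # (1 / num_shards); it only rescales the weight gradients in the backward pass.
        # `get_local_tensor` unwraps DTensor parameters into their local shards.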
# shape (all): (experts_per_rank, hidden_size, d_model) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index effc26195..ad0cae171 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -79,7 +79,9 @@ def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> " ) try: - if self.name == MoEType.dropless: + if self.name == MoEType.default: + return MoE(**kwargs) + elif self.name == MoEType.dropless: return DroplessMoE(**kwargs) else: raise NotImplementedError(self.name) diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py index 8d4b0f597..ab9b4fdf3 100644 --- a/src/olmo_core/nn/moe/ops.py +++ b/src/olmo_core/nn/moe/ops.py @@ -192,7 +192,7 @@ def forward( ctx: Any, x: torch.Tensor, indices: torch.Tensor, - weights: torch.Tensor, + weights: Optional[torch.Tensor], bins: torch.Tensor, top_k: int, ): @@ -238,7 +238,7 @@ def backward(ctx: Any, grad: torch.Tensor): def binned_scatter( x: torch.Tensor, indices: torch.Tensor, - weights: torch.Tensor, + weights: Optional[torch.Tensor], bins: torch.Tensor, top_k: int, ) -> torch.Tensor: @@ -254,7 +254,12 @@ def repeat(x: torch.Tensor, tiling: Union[torch.Size, Tuple[int, ...]]) -> torch class AllToAllOp(torch.autograd.Function): @staticmethod def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + if output_split_sizes is not None: + out = torch.empty( + (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype + ) + else: + out = torch.empty_like(x) ctx.input_shape = x.shape ctx.output_split_sizes = output_split_sizes @@ -291,8 +296,8 @@ def backward(ctx, grad, _): def all_to_all( x: torch.Tensor, - output_split_sizes: List[int], - input_split_sizes: List[int], + output_split_sizes: Optional[List[int]] = None, + input_split_sizes: Optional[List[int]] = None, group: Optional[dist.ProcessGroup] = None, async_op: bool = False, ) -> Tuple[torch.Tensor, Any]: diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 71e44dabc..a049ddc40 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -8,7 +8,9 @@ import torch.nn as nn from torch.distributed import DeviceMesh -from ...distributed.utils import get_world_size +from olmo_core.distributed.utils import get_world_size +from olmo_core.utils import move_to_device + from . import ops from .mlp import MoEMLP @@ -77,6 +79,8 @@ def indices_and_bins( # Sort the expert ids to produce the scatter/gather # indices for the permutation. # shape: (N,), (N,) + # TODO: for non-dropless MoE, should do secondary sort by expert weight so we drop tokens + # with the lowest expert weight. bin_ids, indices = torch.sort(expert_indices) # Calculate the bin bounds for the sorted items/tokens. @@ -169,11 +173,11 @@ def forward_once( x = self.permute_and_compute( x, - indices, - expert_weights, - bins, - expert_capacity, - top_k, + indices=indices, + expert_weights=expert_weights, + bins=bins, + expert_capacity=expert_capacity, + top_k=top_k, ) return x, batch_size_per_expert @@ -183,26 +187,149 @@ def parallel_forward_once( expert_weights: torch.Tensor, expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - del x, expert_weights, expert_indices - raise NotImplementedError + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. 
Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignment. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + + num_items, top_k = expert_weights.shape + + # shape: (N * top_k,) + expert_weights = expert_weights.flatten() + # shape: (N * top_k,) + expert_indices = expert_indices.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) + expert_capacity = self.expert_capacity(top_k, num_items) + + # Permute locally so that the tokens for each device are stored contiguously. + # shape: (num_experts, expert_capacity, d_model) + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # TODO: Fuse this into the prior, local permutation? + if self.hidden_sharding_degree > 1: + # shape: (num_local_experts, ep_world_size // hidden_sharding_degree, expert_capacity, d_model) + x = x.view(self.experts_per_rank, -1, expert_capacity, self.d_model) + # shape: (num_experts * hidden_sharding_degree, expert_capacity, d_model) + x = x.repeat(1, self.hidden_sharding_degree, 1, 1).view( + -1, expert_capacity, self.d_model + ) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + # shape: (num_local_experts * ep_world_size, expert_capacity, d_model) + # = (num_experts, expert_capacity, d_model) + parallel_x, parallel_x_handle = ops.all_to_all( + x, + group=self._ep_pg, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + + # Construct the expert indices for the permuted tokens. + # shape: (num_experts,) = (num_local_experts * ep_world_size,) + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * self.hidden_sharding_degree, + dtype=torch.int32, + device=indices.device, + ), + self.experts_per_rank, + ) + + # shape: (num_experts * expert_capacity,) + parallel_top_expert = torch.repeat_interleave( + parallel_top_expert, + expert_capacity, + output_size=parallel_top_expert.numel() * expert_capacity, + ) + + # shape: (num_experts * expert_capacity,) + _, parallel_indices = torch.sort(parallel_top_expert) + + # Calculate the bins boundaries from the token counts. 
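+            # Because capacity-based routing pads or truncates every expert bucket
+            # to exactly `expert_capacity` rows, these counts are constants rather
+            # than data-dependent, e.g. (illustrative) experts_per_rank=2 with
+            # expert_capacity=5 gives parallel_tokens_per_expert=[5, 5] and
+            # parallel_bins=[5, 10].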
+ # shape: (num_local_experts,) + parallel_tokens_per_expert = move_to_device( + torch.tensor([expert_capacity] * self.experts_per_rank), parallel_indices.device + ) + # shape: (num_local_experts,) + parallel_bins = torch.empty_like(parallel_tokens_per_expert, dtype=torch.int32) + torch.cumsum(parallel_tokens_per_expert, 0, out=parallel_bins) + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + indices=parallel_indices.int(), + expert_weights=None, # expert_weights + bins=parallel_bins, + expert_capacity=expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) + + # Reduce along the hidden sharding to get the final outputs. + # TODO: Fuse this into the following local permutation? + x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + return x, tokens_per_expert.flatten() def permute_and_compute( self, x: torch.Tensor, + *, indices: torch.Tensor, - expert_weights: torch.Tensor, + expert_weights: Optional[torch.Tensor], bins: torch.Tensor, expert_capacity: int, top_k: int, ) -> torch.Tensor: - # Route the tokens for MoE computation. + # shape: (N, d_model) x = x.view(-1, x.shape[-1]) + + # Route the tokens for MoE computation. + # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) # Perform the expert computation. - x = self.mlp(x) + # shape: (num_experts, expert_capacity, d_model) + x = self.mlp(x, torch.tensor([expert_capacity] * x.shape[0])) - # Un-route the data for the MoE output. + # Un-route the data for the MoE output. Items that were dropped will be zeroed out. + # shape: (N, d_model) return ops.binned_scatter(x, indices, expert_weights, bins, top_k) @@ -232,12 +359,12 @@ def forward_once( out = self.permute_and_compute( x, - batch_size_per_expert, - indices, - bin_ids, - expert_weights, - bins, - top_k, + batch_size_per_expert=batch_size_per_expert, + indices=indices, + bin_ids=bin_ids, + expert_weights=expert_weights, + bins=bins, + top_k=top_k, ) return out, batch_size_per_expert @@ -316,7 +443,7 @@ def parallel_forward_once( self.ep_world_size, self.experts_per_rank ) - # TODO: can we avoid the host-device sync? + # NOTE: host-device sync here. 
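+        # The `.cpu().tolist()` calls below block until the token counts are on the
+        # host, since the all-to-all split sizes must be plain Python ints. The
+        # capacity-based `ParallelMLP` avoids this sync by using a fixed
+        # `expert_capacity` instead of data-dependent counts.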
send_counts = repeated_tokens_per_expert.sum(dim=-1).cpu().tolist() recv_counts = parallel_tokens_per_expert.sum(dim=-1).cpu().tolist() tokens_received = sum(recv_counts) @@ -361,12 +488,6 @@ def parallel_forward_once( parallel_tokens_per_expert.flatten(), output_size=tokens_received, ) - # replicate_bins = torch.cumsum(parallel_tokens_per_expert.flatten(), 0) - # parallel_top_expert = ops.replicate( - # parallel_top_expert.unsqueeze(dim=0), - # replicate_bins, - # tokens_received, - # ).flatten() parallel_bin_ids, parallel_indices = torch.sort(parallel_top_expert) @@ -383,12 +504,12 @@ def parallel_forward_once( parallel_x_handle.wait() parallel_x = self.permute_and_compute( parallel_x, - parallel_tokens_per_expert, - parallel_indices.int(), - parallel_bin_ids, - None, # expert_weights - parallel_bins, - 1, + batch_size_per_expert=parallel_tokens_per_expert, + indices=parallel_indices.int(), + bin_ids=parallel_bin_ids, + expert_weights=None, + bins=parallel_bins, + top_k=1, ) # Un-permute the tokens across the devices. @@ -411,7 +532,8 @@ def parallel_forward_once( def permute_and_compute( self, x: torch.Tensor, - tokens_per_expert: torch.Tensor, + *, + batch_size_per_expert: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, expert_weights: Optional[torch.Tensor], @@ -423,7 +545,7 @@ def permute_and_compute( x = ops.gather(x, indices, bin_ids, bins, top_k) # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) + x = self.mlp(x, batch_size_per_expert) # Un-route the data for the MoE output. return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 824589a6d..d8a88a512 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -13,12 +13,11 @@ from olmo_core.utils import get_default_device from ...distributed.utils import requires_multi_gpu, run_distributed_test -from ...utils import requires_gpu, requires_grouped_gemm +from ...utils import requires_gpu @requires_gpu -@requires_grouped_gemm -@pytest.mark.parametrize("moe_type", [MoEType.dropless]) +@pytest.mark.parametrize("moe_type", [MoEType.dropless, MoEType.default]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) def test_moe(moe_type, dtype): d_model = 128 @@ -97,7 +96,6 @@ def run_moe_with_expert_parallelism(moe_type, dtype): @requires_multi_gpu -@requires_grouped_gemm @pytest.mark.parametrize("moe_type", [MoEType.dropless]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) def test_moe_with_expert_parallelism(moe_type, dtype): From ff743e34222450e56b0f28f71e8528e2af6388f2 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 15:46:10 -0800 Subject: [PATCH 054/230] refactor --- src/olmo_core/nn/moe/__init__.py | 3 +- src/olmo_core/nn/moe/mlp.py | 173 +++++++++++++++++++-------- src/olmo_core/nn/moe/moe.py | 62 ++++++++-- src/olmo_core/nn/moe/parallel_mlp.py | 7 +- src/test/nn/moe/mlp_test.py | 50 +++++++- src/test/nn/moe/moe_test.py | 6 +- 6 files changed, 235 insertions(+), 66 deletions(-) diff --git a/src/olmo_core/nn/moe/__init__.py b/src/olmo_core/nn/moe/__init__.py index ce3aad46f..479b3b4ed 100644 --- a/src/olmo_core/nn/moe/__init__.py +++ b/src/olmo_core/nn/moe/__init__.py @@ -2,7 +2,7 @@ MoE layers. 
""" -from .mlp import MoEMLP, MoEMLPConfig, MoEMLPType +from .mlp import DroplessMoEMLP, MoEMLP, MoEMLPConfig, MoEMLPType from .moe import DroplessMoE, MoEBase, MoEConfig, MoEType from .router import MoELinearRouter, MoERouter, MoERouterConfig, MoERouterType from .shared_mlp import SharedMLP, SharedMLPConfig, SharedMLPType @@ -13,6 +13,7 @@ "MoEConfig", "MoEType", "MoEMLP", + "DroplessMoEMLP", "MoEMLPConfig", "MoEMLPType", "SharedMLP", diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index d667db1c3..451f676c9 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -41,14 +41,14 @@ class MoEMLPType(StrEnum): ➡️ :class:`MoEMLP` """ - -@dataclass -class MoEMLPConfig(Config): - name: MoEMLPType = MoEMLPType.default + dropless = "dropless" """ - The name of the implementation. + ➡️ :class:`DroplessMoEMLP` """ + +@dataclass +class MoEMLPConfig(Config): dtype: DType = DType.float32 def num_params(self, d_model: int, num_experts: int, hidden_size: int) -> int: @@ -59,22 +59,21 @@ def num_params(self, d_model: int, num_experts: int, hidden_size: int) -> int: :param num_experts: Then number of experts. :param hidden_size: The hidden size of each expert. """ - num_params = 0 - if self.name == MoEMLPType.default: - num_params += 3 * d_model * hidden_size * num_experts - else: - raise NotImplementedError - - return num_params + return 3 * d_model * hidden_size * num_experts def num_active_params(self, d_model: int, top_k: int, hidden_size: int) -> int: return self.num_params(d_model, top_k, hidden_size) def build( - self, d_model: int, num_experts: int, hidden_size: int, *, init_device: str = "cpu" - ) -> "MoEMLP": + self, + *, + name: MoEMLPType, + d_model: int, + num_experts: int, + hidden_size: int, + init_device: str = "cpu", + ) -> "MoEMLPBase": kwargs = self.as_dict(exclude_none=True, recurse=False) - kwargs.pop("name") kwargs.update( dtype=kwargs.pop("dtype").as_pt(), d_model=d_model, @@ -84,29 +83,25 @@ def build( ) try: - if self.name == MoEMLPType.default: + if name == MoEMLPType.default: return MoEMLP(**kwargs) + elif name == MoEMLPType.dropless: + return DroplessMoEMLP(**kwargs) else: - raise NotImplementedError(self.name) + raise NotImplementedError(name) except TypeError as e: raise OLMoConfigurationError( - f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + f"invalid options for '{name}' {self.__class__.__name__}, {e}" ) from e -class MoEMLP(nn.Module): - """ - A basic expert MLP module with SwiGLU activation. - """ - +class MoEMLPBase(nn.Module): def __init__( self, *, d_model: int, hidden_size: int, num_experts: int, - dtype: torch.dtype = torch.float32, - init_device: str = "cpu", ): super().__init__() self.d_model = d_model @@ -117,6 +112,110 @@ def __init__( self.experts_per_rank = num_experts self.hidden_sharding_degree = 1 + def scale_grad(self, w: torch.Tensor) -> torch.Tensor: + if self.gradient_scale is None: + return w + return _scale_gradient(w, self.gradient_scale) + + def apply_ep(self, ep_mesh: DeviceMesh): + """ + Apply expert parallelism. + """ + if ep_mesh.ndim > 1: + raise RuntimeError("local expert parallel sub-mesh must be 1-dimensional") + num_shards = ep_mesh.size() + if self.num_experts % num_shards != 0: + raise OLMoConfigurationError( + f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel shard degree ({num_shards})." 
+ ) + + self.experts_per_rank = self.num_experts // num_shards + self.gradient_scale = 1.0 / num_shards + + self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) + self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)]))) + self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, [Shard(0)]))) + + +class MoEMLP(MoEMLPBase): + """ + A basic expert MLP module with SwiGLU activation. + """ + + def __init__( + self, + *, + d_model: int, + hidden_size: int, + num_experts: int, + dtype: torch.dtype = torch.float32, + init_device: str = "cpu", + ): + super().__init__(d_model=d_model, hidden_size=hidden_size, num_experts=num_experts) + self.w1 = nn.Parameter( + torch.empty( + num_experts, + d_model, + hidden_size, + device=init_device, + dtype=dtype, + ), + ) + self.w2 = nn.Parameter( + torch.empty( + num_experts, + d_model, + hidden_size, + device=init_device, + dtype=dtype, + ), + ) + self.w3 = nn.Parameter( + torch.empty( + num_experts, + hidden_size, + d_model, + device=init_device, + dtype=dtype, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute the expert outputs. + + :param x: The input of shape ``(*, d_model)``. + """ + # Scale gradients and get local tensors (in case of expert parallelism). + # shape (all): (experts_per_rank, hidden_size, d_model) + w1, w2, w3 = ( + get_local_tensor(self.scale_grad(self.w1)), + get_local_tensor(self.scale_grad(self.w2)), + get_local_tensor(self.scale_grad(self.w3)), + ) + + # Compute the MLP. + x1 = torch.bmm(x, w1) + x2 = torch.bmm(x, w3) + x1 = F.silu(x1) * x2 + return torch.bmm(x1, w2) + + +class DroplessMoEMLP(MoEMLPBase): + """ + A dropless expert MLP module with SwiGLU activation. + """ + + def __init__( + self, + *, + d_model: int, + hidden_size: int, + num_experts: int, + dtype: torch.dtype = torch.float32, + init_device: str = "cpu", + ): + super().__init__(d_model=d_model, hidden_size=hidden_size, num_experts=num_experts) self.w1 = nn.Parameter( torch.empty( num_experts, @@ -158,11 +257,6 @@ def __init__( "https://github.com/tgale96/grouped_gemm" ) - def scale_grad(self, w: torch.Tensor) -> torch.Tensor: - if self.gradient_scale is None: - return w - return _scale_gradient(w, self.gradient_scale) - def gmm( self, x: torch.Tensor, w: torch.Tensor, batch_sizes: torch.Tensor, trans_b: bool = False ) -> torch.Tensor: @@ -198,22 +292,3 @@ def forward(self, x: torch.Tensor, batch_size_per_expert: torch.Tensor) -> torch x2 = self.gmm(x, w3, batch_size_per_expert, trans_b=True) x1 = F.silu(x1) * x2 return self.gmm(x1, w2, batch_size_per_expert) - - def apply_ep(self, ep_mesh: DeviceMesh): - """ - Apply expert parallelism. - """ - if ep_mesh.ndim > 1: - raise RuntimeError("local expert parallel sub-mesh must be 1-dimensional") - num_shards = ep_mesh.size() - if self.num_experts % num_shards != 0: - raise OLMoConfigurationError( - f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel shard degree ({num_shards})." 
- ) - - self.experts_per_rank = self.num_experts // num_shards - self.gradient_scale = 1.0 / num_shards - - self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) - self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)]))) - self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, [Shard(0)]))) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index ad0cae171..82193606d 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -1,6 +1,6 @@ from abc import abstractmethod from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, cast import torch import torch.nn as nn @@ -13,7 +13,7 @@ from olmo_core.exceptions import OLMoConfigurationError from .loss import MoELoadBalancingLoss, MoELoss, MoERouterZLoss -from .mlp import MoEMLP, MoEMLPConfig +from .mlp import DroplessMoEMLP, MoEMLP, MoEMLPConfig, MoEMLPType from .parallel_mlp import ParallelDroplessMLP, ParallelMLP, ParallelMLPBase from .router import MoERouterConfig from .shared_mlp import SharedMLPConfig @@ -114,7 +114,7 @@ def __init__( super().__init__() self.router = router.build(d_model, num_experts, init_device=init_device) self.experts = self._init_parallel_mlp( - mlp.build(d_model, num_experts, hidden_size, init_device=init_device), **kwargs + mlp, d_model, num_experts, hidden_size, init_device, **kwargs ) self.shared_experts = ( None @@ -152,7 +152,15 @@ def reset_losses(self): loss_fn.reset() @abstractmethod - def _init_parallel_mlp(self, mlp: MoEMLP, **kwargs) -> ParallelMLPBase: + def _init_parallel_mlp( + self, + mlp: MoEMLPConfig, + d_model: int, + num_experts: int, + hidden_size: int, + init_device: str = "cpu", + **kwargs, + ) -> ParallelMLPBase: raise NotImplementedError def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -239,8 +247,28 @@ def __init__( capacity_factor=capacity_factor, ) - def _init_parallel_mlp(self, mlp: MoEMLP, *, capacity_factor: float) -> ParallelMLP: # type: ignore[override] - return ParallelMLP(mlp=mlp, capacity_factor=capacity_factor) + def _init_parallel_mlp( + self, + mlp: MoEMLPConfig, + d_model: int, + num_experts: int, + hidden_size: int, + capacity_factor: float, + init_device: str = "cpu", + ) -> ParallelMLP: + return ParallelMLP( + mlp=cast( + MoEMLP, + mlp.build( + name=MoEMLPType.default, + d_model=d_model, + num_experts=num_experts, + hidden_size=hidden_size, + init_device=init_device, + ), + ), + capacity_factor=capacity_factor, + ) class DroplessMoE(MoEBase): @@ -248,5 +276,23 @@ class DroplessMoE(MoEBase): A dropless MoE implementation. """ - def _init_parallel_mlp(self, mlp: MoEMLP) -> ParallelDroplessMLP: # type: ignore[override] - return ParallelDroplessMLP(mlp=mlp) + def _init_parallel_mlp( + self, + mlp: MoEMLPConfig, + d_model: int, + num_experts: int, + hidden_size: int, + init_device: str = "cpu", + ) -> ParallelDroplessMLP: + return ParallelDroplessMLP( + mlp=cast( + DroplessMoEMLP, + mlp.build( + name=MoEMLPType.dropless, + d_model=d_model, + num_experts=num_experts, + hidden_size=hidden_size, + init_device=init_device, + ), + ), + ) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index a049ddc40..206cfaba1 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -12,7 +12,7 @@ from olmo_core.utils import move_to_device from . 
import ops -from .mlp import MoEMLP +from .mlp import DroplessMoEMLP, MoEMLP, MoEMLPBase __all__ = ["ParallelMLPBase", "ParallelMLP", "ParallelDroplessMLP"] @@ -22,7 +22,7 @@ class ParallelMLPBase(nn.Module): Wraps an MoE MLP layer to coordinate the routing and expert parallelism. """ - def __init__(self, *, mlp: MoEMLP): + def __init__(self, *, mlp: MoEMLPBase): super().__init__() self.mlp = mlp self._expert_parallel_enabled: bool = False @@ -341,6 +341,9 @@ class ParallelDroplessMLP(ParallelMLPBase): When expert parallelism is enabled the forward pass involves a host-device sync. """ + def __init__(self, *, mlp: DroplessMoEMLP): + super().__init__(mlp=mlp) + def forward_once( self, x: torch.Tensor, diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index 767617665..5c66253d9 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -6,11 +6,11 @@ build_expert_parallel_mesh, ) from olmo_core.distributed.utils import get_local_tensor -from olmo_core.nn.moe.mlp import MoEMLP +from olmo_core.nn.moe.mlp import DroplessMoEMLP, MoEMLP from olmo_core.utils import get_default_device from ...distributed.utils import requires_multi_gpu, run_distributed_test -from ...utils import requires_gpu +from ...utils import requires_gpu, requires_grouped_gemm @requires_gpu @@ -18,6 +18,18 @@ def test_mlp(): mlp = MoEMLP( d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 ) + x = torch.randn(6, 128, device="cuda", dtype=torch.bfloat16) + tokens_per_expert = torch.tensor([3, 3], device="cuda") + out = mlp(x, tokens_per_expert) + assert out.shape == (6, 128) + + +@requires_gpu +@requires_grouped_gemm +def test_dropless_mlp(): + mlp = DroplessMoEMLP( + d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 + ) x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) tokens_per_expert = torch.tensor([3, 2], device="cuda") out = mlp(x, tokens_per_expert) @@ -38,13 +50,41 @@ def run_mlp_with_expert_parallelism(): mlp.to_empty(device=get_default_device()) assert get_local_tensor(mlp.w1).shape == (2, 256, 128) - x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) - tokens_per_expert = torch.tensor([3, 2], device="cuda") + x = torch.randn(6, 128, device="cuda", dtype=torch.bfloat16) + tokens_per_expert = torch.tensor([3, 3], device="cuda") out = mlp(x, tokens_per_expert) - assert out.shape == (5, 128) + assert out.shape == (6, 128) @requires_multi_gpu def test_mlp_with_expert_parallelism(): run_distributed_test(run_mlp_with_expert_parallelism, backend="nccl", start_method="spawn") + + +def run_dropless_mlp_with_expert_parallelism(): + ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=dist.get_world_size())) + + mlp = MoEMLP( + d_model=128, + hidden_size=256, + num_experts=dist.get_world_size() * 2, + init_device="meta", + dtype=torch.bfloat16, + ) + mlp.apply_ep(ep_mesh) + mlp.to_empty(device=get_default_device()) + assert get_local_tensor(mlp.w1).shape == (2, 256, 128) + + x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) + tokens_per_expert = torch.tensor([2, 3], device="cuda") + out = mlp(x, tokens_per_expert) + + assert out.shape == (5, 128) + + +@requires_multi_gpu +def test_dropless_mlp_with_expert_parallelism(): + run_distributed_test( + run_dropless_mlp_with_expert_parallelism, backend="nccl", start_method="spawn" + ) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index d8a88a512..3da54de88 100644 --- a/src/test/nn/moe/moe_test.py +++ 
b/src/test/nn/moe/moe_test.py @@ -10,7 +10,7 @@ build_expert_parallel_mesh, ) from olmo_core.nn.moe import MoEConfig, MoEMLPConfig, MoERouterConfig, MoEType -from olmo_core.utils import get_default_device +from olmo_core.utils import get_default_device, seed_all from ...distributed.utils import requires_multi_gpu, run_distributed_test from ...utils import requires_gpu @@ -20,6 +20,8 @@ @pytest.mark.parametrize("moe_type", [MoEType.dropless, MoEType.default]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) def test_moe(moe_type, dtype): + seed_all(42) + d_model = 128 config = MoEConfig( name=moe_type, @@ -61,6 +63,8 @@ def test_moe(moe_type, dtype): def run_moe_with_expert_parallelism(moe_type, dtype): + seed_all(42) + ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=min(dist.get_world_size(), 2))) d_model = 128 From aa7da18f342ec3fd6e91331aa09af20db69b5ca6 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 15:49:26 -0800 Subject: [PATCH 055/230] fix --- src/test/nn/moe/mlp_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index 5c66253d9..605264a07 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -48,7 +48,7 @@ def run_mlp_with_expert_parallelism(): ) mlp.apply_ep(ep_mesh) mlp.to_empty(device=get_default_device()) - assert get_local_tensor(mlp.w1).shape == (2, 256, 128) + assert get_local_tensor(mlp.w1).shape == (2, 128, 256) x = torch.randn(6, 128, device="cuda", dtype=torch.bfloat16) tokens_per_expert = torch.tensor([3, 3], device="cuda") @@ -65,7 +65,7 @@ def test_mlp_with_expert_parallelism(): def run_dropless_mlp_with_expert_parallelism(): ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=dist.get_world_size())) - mlp = MoEMLP( + mlp = DroplessMoEMLP( d_model=128, hidden_size=256, num_experts=dist.get_world_size() * 2, @@ -84,6 +84,7 @@ def run_dropless_mlp_with_expert_parallelism(): @requires_multi_gpu +@requires_grouped_gemm def test_dropless_mlp_with_expert_parallelism(): run_distributed_test( run_dropless_mlp_with_expert_parallelism, backend="nccl", start_method="spawn" From 76c5ec2b70d2266e09870f38d042e67fb1b52024 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 15:51:27 -0800 Subject: [PATCH 056/230] fix --- src/olmo_core/nn/moe/mlp.py | 4 ++-- src/test/nn/moe/mlp_test.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 451f676c9..b10eb6ac4 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -12,7 +12,7 @@ from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError -__all__ = ["MoEMLP", "MoEMLPConfig", "MoEMLPType"] +__all__ = ["MoEMLP", "DroplessMoEMLP", "MoEMLPConfig", "MoEMLPType"] class _ScaleGradient(torch.autograd.Function): @@ -184,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ Compute the expert outputs. - :param x: The input of shape ``(*, d_model)``. + :param x: The input of shape ``(num_local_experts, N, d_model)``. """ # Scale gradients and get local tensors (in case of expert parallelism). 
# shape (all): (experts_per_rank, hidden_size, d_model) diff --git a/src/test/nn/moe/mlp_test.py b/src/test/nn/moe/mlp_test.py index 605264a07..f357406b3 100644 --- a/src/test/nn/moe/mlp_test.py +++ b/src/test/nn/moe/mlp_test.py @@ -18,10 +18,9 @@ def test_mlp(): mlp = MoEMLP( d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 ) - x = torch.randn(6, 128, device="cuda", dtype=torch.bfloat16) - tokens_per_expert = torch.tensor([3, 3], device="cuda") - out = mlp(x, tokens_per_expert) - assert out.shape == (6, 128) + x = torch.randn(2, 3, 128, device="cuda", dtype=torch.bfloat16) + out = mlp(x) + assert out.shape == (2, 3, 128) @requires_gpu @@ -50,11 +49,10 @@ def run_mlp_with_expert_parallelism(): mlp.to_empty(device=get_default_device()) assert get_local_tensor(mlp.w1).shape == (2, 128, 256) - x = torch.randn(6, 128, device="cuda", dtype=torch.bfloat16) - tokens_per_expert = torch.tensor([3, 3], device="cuda") - out = mlp(x, tokens_per_expert) + x = torch.randn(2, 3, 128, device="cuda", dtype=torch.bfloat16) + out = mlp(x) - assert out.shape == (6, 128) + assert out.shape == (2, 3, 128) @requires_multi_gpu From 8fb887d296ac086bda0bb99e9b56579a13b912d6 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 15:53:09 -0800 Subject: [PATCH 057/230] fix --- src/olmo_core/nn/moe/mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index b10eb6ac4..fc1538583 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -164,8 +164,8 @@ def __init__( self.w2 = nn.Parameter( torch.empty( num_experts, - d_model, hidden_size, + d_model, device=init_device, dtype=dtype, ), @@ -173,8 +173,8 @@ def __init__( self.w3 = nn.Parameter( torch.empty( num_experts, - hidden_size, d_model, + hidden_size, device=init_device, dtype=dtype, ), From 6951bb1f3de7a609ef5f5e6d20d3f476f767e553 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 15:55:20 -0800 Subject: [PATCH 058/230] clean up --- src/olmo_core/nn/moe/mlp.py | 6 +++--- src/olmo_core/nn/moe/moe.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index fc1538583..c2c4c4a0c 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -132,9 +132,9 @@ def apply_ep(self, ep_mesh: DeviceMesh): self.experts_per_rank = self.num_experts // num_shards self.gradient_scale = 1.0 / num_shards - self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) - self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)]))) - self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, [Shard(0)]))) + self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) # type: ignore + self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)]))) # type: ignore + self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, [Shard(0)]))) # type: ignore class MoEMLP(MoEMLPBase): diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 82193606d..05eb19050 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -247,7 +247,7 @@ def __init__( capacity_factor=capacity_factor, ) - def _init_parallel_mlp( + def _init_parallel_mlp( # type: ignore[override] self, mlp: MoEMLPConfig, d_model: int, @@ -276,7 +276,7 @@ class DroplessMoE(MoEBase): A 
dropless MoE implementation. """ - def _init_parallel_mlp( + def _init_parallel_mlp( # type: ignore[override] self, mlp: MoEMLPConfig, d_model: int, From c8be5dcf70d774ca6367b3b10eeb229737372646 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 15:59:18 -0800 Subject: [PATCH 059/230] fix --- src/olmo_core/nn/moe/moe.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 05eb19050..5ce7c025e 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -114,7 +114,12 @@ def __init__( super().__init__() self.router = router.build(d_model, num_experts, init_device=init_device) self.experts = self._init_parallel_mlp( - mlp, d_model, num_experts, hidden_size, init_device, **kwargs + mlp, + d_model=d_model, + num_experts=num_experts, + hidden_size=hidden_size, + init_device=init_device, + **kwargs, ) self.shared_experts = ( None @@ -155,6 +160,7 @@ def reset_losses(self): def _init_parallel_mlp( self, mlp: MoEMLPConfig, + *, d_model: int, num_experts: int, hidden_size: int, @@ -250,6 +256,7 @@ def __init__( def _init_parallel_mlp( # type: ignore[override] self, mlp: MoEMLPConfig, + *, d_model: int, num_experts: int, hidden_size: int, @@ -279,6 +286,7 @@ class DroplessMoE(MoEBase): def _init_parallel_mlp( # type: ignore[override] self, mlp: MoEMLPConfig, + *, d_model: int, num_experts: int, hidden_size: int, From 774b77ffd78b0bffff83a4fdf39527b53b9087ec Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 16:00:56 -0800 Subject: [PATCH 060/230] fix --- src/olmo_core/nn/moe/parallel_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 206cfaba1..4c17d2fa7 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -326,7 +326,7 @@ def permute_and_compute( # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) - x = self.mlp(x, torch.tensor([expert_capacity] * x.shape[0])) + x = self.mlp(x) # Un-route the data for the MoE output. Items that were dropped will be zeroed out. # shape: (N, d_model) From 7ecebd28234a137071eec741840bbbcc8e01fda8 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 16:02:09 -0800 Subject: [PATCH 061/230] add parallel test for default --- src/test/nn/moe/moe_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 3da54de88..fc9bc8983 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -100,7 +100,7 @@ def run_moe_with_expert_parallelism(moe_type, dtype): @requires_multi_gpu -@pytest.mark.parametrize("moe_type", [MoEType.dropless]) +@pytest.mark.parametrize("moe_type", [MoEType.dropless, MoEType.default]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) def test_moe_with_expert_parallelism(moe_type, dtype): run_distributed_test( From 6841a6af8dd94a22a1668e0ffa8ed79007d0b7d3 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 3 Feb 2025 16:04:48 -0800 Subject: [PATCH 062/230] fix? 
--- src/olmo_core/nn/moe/parallel_mlp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 4c17d2fa7..1184c4a26 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -206,6 +206,8 @@ def parallel_forward_once( # After this series of permutations we compute the linear layers # and then repeat these three steps in reverse to produce the final # output. + # shape: (N, d_model) + x = x.view(-1, x.shape[-1]) num_items, top_k = expert_weights.shape From 62f5496d7118e3572ff43bbcf1eeb3eef0c79ab9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 11:38:26 -0800 Subject: [PATCH 063/230] clean up --- src/olmo_core/nn/moe/parallel_mlp.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 1184c4a26..c47238a72 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -380,25 +380,9 @@ def parallel_forward_once( expert_weights: torch.Tensor, expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignment. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. + # NOTE: This function does the same thing as `ParallelMLP.parallel_forward_once()` + # but with extra bookkeeping to manage the dynamic sizes, and unfortunately this introduces + # a host-device sync. top_k = expert_weights.shape[-1] From 45482d8e79fc5e350c0851b0ddf4c19880cbf50f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:03:23 -0800 Subject: [PATCH 064/230] clean up --- src/olmo_core/nn/moe/__init__.py | 4 +- src/olmo_core/nn/moe/mlp.py | 68 +--------------------------- src/olmo_core/nn/moe/moe.py | 65 ++++++++++++-------------- src/olmo_core/nn/moe/parallel_mlp.py | 3 +- src/olmo_core/nn/moe/router.py | 16 +++++-- src/olmo_core/nn/moe/shared_mlp.py | 17 +++++-- src/scripts/train/OLMoE-1B-7B.py | 3 +- src/test/nn/moe/moe_test.py | 6 +-- 8 files changed, 64 insertions(+), 118 deletions(-) diff --git a/src/olmo_core/nn/moe/__init__.py b/src/olmo_core/nn/moe/__init__.py index 479b3b4ed..f249dc6cf 100644 --- a/src/olmo_core/nn/moe/__init__.py +++ b/src/olmo_core/nn/moe/__init__.py @@ -2,7 +2,7 @@ MoE layers. 
""" -from .mlp import DroplessMoEMLP, MoEMLP, MoEMLPConfig, MoEMLPType +from .mlp import DroplessMoEMLP, MoEMLP from .moe import DroplessMoE, MoEBase, MoEConfig, MoEType from .router import MoELinearRouter, MoERouter, MoERouterConfig, MoERouterType from .shared_mlp import SharedMLP, SharedMLPConfig, SharedMLPType @@ -14,8 +14,6 @@ "MoEType", "MoEMLP", "DroplessMoEMLP", - "MoEMLPConfig", - "MoEMLPType", "SharedMLP", "SharedMLPConfig", "SharedMLPType", diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index c2c4c4a0c..899994742 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -1,5 +1,4 @@ import warnings -from dataclasses import dataclass from typing import Any, Callable, Optional import torch @@ -8,11 +7,10 @@ from torch.distributed import DeviceMesh from torch.distributed.tensor import Shard, distribute_tensor -from ...config import Config, DType, StrEnum from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError -__all__ = ["MoEMLP", "DroplessMoEMLP", "MoEMLPConfig", "MoEMLPType"] +__all__ = ["MoEMLP", "DroplessMoEMLP"] class _ScaleGradient(torch.autograd.Function): @@ -31,70 +29,6 @@ def backward(ctx: torch.Tensor, grad: torch.Tensor): _scale_gradient: Callable[[torch.Tensor, float], torch.Tensor] = _ScaleGradient.apply # type: ignore -class MoEMLPType(StrEnum): - """ - An enumeration of the different MoE expert MLP implementations. - """ - - default = "default" - """ - ➡️ :class:`MoEMLP` - """ - - dropless = "dropless" - """ - ➡️ :class:`DroplessMoEMLP` - """ - - -@dataclass -class MoEMLPConfig(Config): - dtype: DType = DType.float32 - - def num_params(self, d_model: int, num_experts: int, hidden_size: int) -> int: - """ - The number of params that the module will have once built. - - :param d_model: The model dimensionality. - :param num_experts: Then number of experts. - :param hidden_size: The hidden size of each expert. 
- """ - return 3 * d_model * hidden_size * num_experts - - def num_active_params(self, d_model: int, top_k: int, hidden_size: int) -> int: - return self.num_params(d_model, top_k, hidden_size) - - def build( - self, - *, - name: MoEMLPType, - d_model: int, - num_experts: int, - hidden_size: int, - init_device: str = "cpu", - ) -> "MoEMLPBase": - kwargs = self.as_dict(exclude_none=True, recurse=False) - kwargs.update( - dtype=kwargs.pop("dtype").as_pt(), - d_model=d_model, - num_experts=num_experts, - hidden_size=hidden_size, - init_device=init_device, - ) - - try: - if name == MoEMLPType.default: - return MoEMLP(**kwargs) - elif name == MoEMLPType.dropless: - return DroplessMoEMLP(**kwargs) - else: - raise NotImplementedError(name) - except TypeError as e: - raise OLMoConfigurationError( - f"invalid options for '{name}' {self.__class__.__name__}, {e}" - ) from e - - class MoEMLPBase(nn.Module): def __init__( self, diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 5ce7c025e..f4704ea85 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -1,6 +1,6 @@ from abc import abstractmethod from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union, cast +from typing import Dict, List, Optional, Union import torch import torch.nn as nn @@ -8,12 +8,12 @@ from torch.distributed.tensor import Placement, Replicate, Shard from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module -from olmo_core.config import Config, StrEnum +from olmo_core.config import Config, DType, StrEnum from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel from olmo_core.exceptions import OLMoConfigurationError from .loss import MoELoadBalancingLoss, MoELoss, MoERouterZLoss -from .mlp import DroplessMoEMLP, MoEMLP, MoEMLPConfig, MoEMLPType +from .mlp import DroplessMoEMLP, MoEMLP from .parallel_mlp import ParallelDroplessMLP, ParallelMLP, ParallelMLPBase from .router import MoERouterConfig from .shared_mlp import SharedMLPConfig @@ -47,26 +47,24 @@ class MoEConfig(Config): hidden_size: int = 256 capacity_factor: Optional[float] = None router: MoERouterConfig = field(default_factory=MoERouterConfig) - mlp: MoEMLPConfig = field(default_factory=MoEMLPConfig) shared_mlp: Optional[SharedMLPConfig] = None lb_loss_weight: Optional[float] = 1.0 z_loss_weight: Optional[float] = None + dtype: DType = DType.float32 def num_params(self, d_model: int) -> int: num_params = 0 - num_params += self.router.num_params(d_model, self.num_experts) - num_params += self.mlp.num_params(d_model, self.num_experts, self.hidden_size) + num_params += 3 * d_model * self.hidden_size * self.num_experts if self.shared_mlp is not None: num_params += self.shared_mlp.num_params(d_model, self.hidden_size) - return num_params def num_active_params(self, d_model: int) -> int: return ( self.num_params(d_model) - - self.mlp.num_params(d_model, self.num_experts, self.hidden_size) - + self.mlp.num_active_params(d_model, self.router.top_k, self.hidden_size) + - (3 * d_model * self.hidden_size * self.num_experts) + + (3 * d_model * self.hidden_size * self.router.top_k) ) def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> "MoEBase": @@ -76,6 +74,7 @@ def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> " d_model=d_model, num_layers=num_layers, init_device=init_device, + dtype=kwargs.pop("dtype").as_pt(), ) try: @@ -103,28 +102,28 @@ def __init__( num_experts: int, hidden_size: int, router: 
MoERouterConfig, - mlp: MoEMLPConfig, num_layers: int, shared_mlp: Optional[SharedMLPConfig] = None, init_device: str = "cpu", lb_loss_weight: Optional[float] = None, z_loss_weight: Optional[float] = None, + dtype: torch.dtype = torch.float32, **kwargs, ): super().__init__() - self.router = router.build(d_model, num_experts, init_device=init_device) + self.router = router.build(d_model, num_experts, dtype=dtype, init_device=init_device) self.experts = self._init_parallel_mlp( - mlp, d_model=d_model, num_experts=num_experts, hidden_size=hidden_size, + dtype=dtype, init_device=init_device, **kwargs, ) self.shared_experts = ( None if shared_mlp is None - else shared_mlp.build(d_model, hidden_size, init_device=init_device) + else shared_mlp.build(d_model, hidden_size, dtype=dtype, init_device=init_device) ) self.num_layers = num_layers self.losses: List[MoELoss] = [] @@ -159,11 +158,11 @@ def reset_losses(self): @abstractmethod def _init_parallel_mlp( self, - mlp: MoEMLPConfig, *, d_model: int, num_experts: int, hidden_size: int, + dtype: torch.dtype = torch.float32, init_device: str = "cpu", **kwargs, ) -> ParallelMLPBase: @@ -231,48 +230,45 @@ def __init__( num_experts: int, hidden_size: int, router: MoERouterConfig, - mlp: MoEMLPConfig, num_layers: int, shared_mlp: Optional[SharedMLPConfig] = None, capacity_factor: float = 1.2, init_device: str = "cpu", lb_loss_weight: Optional[float] = None, z_loss_weight: Optional[float] = None, + dtype: torch.dtype = torch.float32, ): super().__init__( d_model=d_model, num_experts=num_experts, hidden_size=hidden_size, router=router, - mlp=mlp, num_layers=num_layers, shared_mlp=shared_mlp, init_device=init_device, lb_loss_weight=lb_loss_weight, z_loss_weight=z_loss_weight, + dtype=dtype, capacity_factor=capacity_factor, ) def _init_parallel_mlp( # type: ignore[override] self, - mlp: MoEMLPConfig, *, d_model: int, num_experts: int, hidden_size: int, capacity_factor: float, + dtype: torch.dtype = torch.float32, init_device: str = "cpu", ) -> ParallelMLP: return ParallelMLP( - mlp=cast( - MoEMLP, - mlp.build( - name=MoEMLPType.default, - d_model=d_model, - num_experts=num_experts, - hidden_size=hidden_size, - init_device=init_device, - ), + mlp=MoEMLP( + d_model=d_model, + hidden_size=hidden_size, + num_experts=num_experts, + dtype=dtype, + init_device=init_device, ), capacity_factor=capacity_factor, ) @@ -285,22 +281,19 @@ class DroplessMoE(MoEBase): def _init_parallel_mlp( # type: ignore[override] self, - mlp: MoEMLPConfig, *, d_model: int, num_experts: int, hidden_size: int, + dtype: torch.dtype = torch.float32, init_device: str = "cpu", ) -> ParallelDroplessMLP: return ParallelDroplessMLP( - mlp=cast( - DroplessMoEMLP, - mlp.build( - name=MoEMLPType.dropless, - d_model=d_model, - num_experts=num_experts, - hidden_size=hidden_size, - init_device=init_device, - ), + mlp=DroplessMoEMLP( + d_model=d_model, + num_experts=num_experts, + hidden_size=hidden_size, + dtype=dtype, + init_device=init_device, ), ) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index c47238a72..385ae54ec 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -1,4 +1,5 @@ -# Adapted from 'https://github.com/databricks/megablocks/blob/main/megablocks/layers/moe.py' and 'dmoe.py' +# This code was originally adapted from 'https://github.com/databricks/megablocks/blob/main/megablocks/layers/moe.py'. +# It has since changed substantially. 
from abc import abstractmethod from typing import Optional, Tuple diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 6579a9464..14f7a9c6b 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -55,7 +55,7 @@ class MoERouterConfig(Config): normalize_expert_weights: Optional[float] = None uniform_expert_assignment: bool = False bias: bool = True - dtype: DType = DType.float32 + dtype: Optional[DType] = None def num_params(self, d_model: int, num_experts: int) -> int: """ @@ -73,7 +73,14 @@ def num_params(self, d_model: int, num_experts: int) -> int: return num_params - def build(self, d_model: int, num_experts, *, init_device: str = "cpu") -> "MoERouter": + def build( + self, + d_model: int, + num_experts, + *, + dtype: Optional[torch.dtype] = None, + init_device: str = "cpu", + ) -> "MoERouter": """ Build the corresponding MoE router module. @@ -84,11 +91,14 @@ def build(self, d_model: int, num_experts, *, init_device: str = "cpu") -> "MoER kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") kwargs.update( - dtype=kwargs.pop("dtype").as_pt(), d_model=d_model, num_experts=num_experts, init_device=init_device, ) + if self.dtype is not None: + kwargs["dtype"] = self.dtype.as_pt() + elif dtype is not None: + kwargs["dtype"] = dtype try: if self.name == MoERouterType.default: diff --git a/src/olmo_core/nn/moe/shared_mlp.py b/src/olmo_core/nn/moe/shared_mlp.py index e3f9523a3..dbcf8a277 100644 --- a/src/olmo_core/nn/moe/shared_mlp.py +++ b/src/olmo_core/nn/moe/shared_mlp.py @@ -33,8 +33,8 @@ class SharedMLPConfig(Config): The name of the implementation. """ weighted_sum: bool = True - bias: Optional[bool] = None - dtype: DType = DType.float32 + bias: bool = True + dtype: Optional[DType] = None def num_params(self, d_model: int, hidden_size: int) -> int: """ @@ -50,7 +50,14 @@ def num_params(self, d_model: int, hidden_size: int) -> int: return params - def build(self, d_model: int, hidden_size: int, *, init_device: str = "cpu") -> "SharedMLP": + def build( + self, + d_model: int, + hidden_size: int, + *, + dtype: Optional[torch.dtype] = None, + init_device: str = "cpu", + ) -> "SharedMLP": """ Build the corresponding shared MLP module. 
@@ -65,6 +72,10 @@ def build(self, d_model: int, hidden_size: int, *, init_device: str = "cpu") -> init_device=init_device, dtype=kwargs.pop("dtype").as_pt(), ) + if self.dtype is not None: + kwargs["dtype"] = self.dtype.as_pt() + elif dtype is not None: + kwargs["dtype"] = dtype try: if self.name == SharedMLPType.default: diff --git a/src/scripts/train/OLMoE-1B-7B.py b/src/scripts/train/OLMoE-1B-7B.py index 8e686f57d..3824bac38 100644 --- a/src/scripts/train/OLMoE-1B-7B.py +++ b/src/scripts/train/OLMoE-1B-7B.py @@ -6,7 +6,7 @@ from olmo_core.config import DType from olmo_core.distributed.parallel import DataParallelType from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.moe import MoEConfig, MoEMLPConfig, MoERouterConfig, MoEType +from olmo_core.nn.moe import MoEConfig, MoERouterConfig, MoEType from olmo_core.nn.transformer import ( TransformerBlockType, TransformerConfig, @@ -36,7 +36,6 @@ def build_model_config(common: CommonComponents) -> TransformerConfig: num_experts=64, hidden_size=int(0.5 * model_config.d_model), router=MoERouterConfig(top_k=8, bias=False), - mlp=MoEMLPConfig(), lb_loss_weight=0.01, z_loss_weight=0.001, ) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index fc9bc8983..9eda0e9d9 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -9,7 +9,7 @@ ExpertParallelConfig, build_expert_parallel_mesh, ) -from olmo_core.nn.moe import MoEConfig, MoEMLPConfig, MoERouterConfig, MoEType +from olmo_core.nn.moe import MoEConfig, MoERouterConfig, MoEType from olmo_core.utils import get_default_device, seed_all from ...distributed.utils import requires_multi_gpu, run_distributed_test @@ -28,8 +28,8 @@ def test_moe(moe_type, dtype): num_experts=4, hidden_size=256, router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), - mlp=MoEMLPConfig(dtype=DType.from_pt(dtype)), z_loss_weight=0.1, + dtype=DType.from_pt(dtype), ) moe = config.build(d_model=d_model, num_layers=1, init_device="cuda") @@ -73,8 +73,8 @@ def run_moe_with_expert_parallelism(moe_type, dtype): num_experts=4, hidden_size=256, router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), - mlp=MoEMLPConfig(dtype=DType.from_pt(dtype)), z_loss_weight=0.1, + dtype=DType.from_pt(dtype), ) moe = config.build(d_model=d_model, num_layers=1, init_device="meta") moe.apply_ep(ep_mesh) From 4f38c0197676f4c9069e837ac44122fc587485a4 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:37:09 -0800 Subject: [PATCH 065/230] improve test --- src/test/nn/moe/moe_test.py | 70 ++++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 9eda0e9d9..4d6629483 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -1,10 +1,16 @@ import math +from pathlib import Path import pytest import torch import torch.distributed as dist +from torch.distributed.tensor import Shard, distribute_tensor from olmo_core.config import DType +from olmo_core.distributed.checkpoint import ( + load_model_and_optim_state, + save_model_and_optim_state, +) from olmo_core.distributed.parallel import ( ExpertParallelConfig, build_expert_parallel_mesh, @@ -62,50 +68,74 @@ def test_moe(moe_type, dtype): assert x.grad is not None -def run_moe_with_expert_parallelism(moe_type, dtype): +def run_moe_with_expert_parallelism( + checkpoint_dir: Path, + config: MoEConfig, + d_model: int, + batch: torch.Tensor, + expected_output: torch.Tensor, +): seed_all(42) 
ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=min(dist.get_world_size(), 2))) - d_model = 128 - config = MoEConfig( - name=moe_type, - num_experts=4, - hidden_size=256, - router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), - z_loss_weight=0.1, - dtype=DType.from_pt(dtype), - ) moe = config.build(d_model=d_model, num_layers=1, init_device="meta") moe.apply_ep(ep_mesh) moe.to_empty(device=get_default_device()) - # Run forward pass. - B, S = 2, 16 - x = torch.randn(B, S, d_model, dtype=dtype, device="cuda", requires_grad=True) - - output = moe(x) - assert output.shape == x.shape + # Load checkpoint. + load_model_and_optim_state(checkpoint_dir, moe) - losses = moe.compute_losses(B * S) + # Run forward pass. + total_tokens = batch.shape[0] * batch.shape[1] + batch = batch.cuda().requires_grad_(True) + batch = distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),)) + output = moe(batch) + assert output.shape == batch.shape + torch.testing.assert_close(output, expected_output) + + losses = moe.compute_losses(total_tokens // ep_mesh.size()) lb_loss = losses["load balancing loss"] assert math.isfinite(lb_loss.item()) + z_loss = losses["router Z loss"] assert math.isfinite(z_loss.item()) loss = lb_loss + z_loss # Run backward pass. loss.backward() - assert x.grad is not None + assert batch.grad is not None @requires_multi_gpu @pytest.mark.parametrize("moe_type", [MoEType.dropless, MoEType.default]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) -def test_moe_with_expert_parallelism(moe_type, dtype): +def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: torch.dtype): + seed_all(42) + + d_model = 128 + config = MoEConfig( + name=moe_type, + num_experts=4, + hidden_size=256, + router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), + z_loss_weight=0.1, + dtype=DType.from_pt(dtype), + ) + moe = config.build(d_model=d_model, num_layers=1, init_device="cpu") + moe.to(device=get_default_device()) + + # Save checkpoint. + save_model_and_optim_state(tmp_path, moe) + + B, S = 4, 16 + batch = torch.randn(B, S, d_model, dtype=dtype, device="cuda", requires_grad=True) + output = moe(batch) + assert output.shape == batch.shape + run_distributed_test( run_moe_with_expert_parallelism, backend="nccl", start_method="spawn", - func_args=(moe_type, dtype), + func_args=(tmp_path, config, d_model, batch.detach().cpu(), output.detach().cpu()), ) From 3a1b06c7c74fc488e4ba38519a545989d7d838b4 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:38:17 -0800 Subject: [PATCH 066/230] fix --- src/test/nn/moe/moe_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 4d6629483..b00e61721 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -88,11 +88,11 @@ def run_moe_with_expert_parallelism( # Run forward pass. 
total_tokens = batch.shape[0] * batch.shape[1] - batch = batch.cuda().requires_grad_(True) + batch = batch.to(device=get_default_device()).requires_grad_(True) batch = distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),)) output = moe(batch) assert output.shape == batch.shape - torch.testing.assert_close(output, expected_output) + torch.testing.assert_close(output, expected_output.to(device=output.device)) losses = moe.compute_losses(total_tokens // ep_mesh.size()) lb_loss = losses["load balancing loss"] From f0851713af6d58bc422cfa1c3ea2582355e94074 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:44:22 -0800 Subject: [PATCH 067/230] fix --- src/test/nn/moe/moe_test.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index b00e61721..293623f7c 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -86,10 +86,12 @@ def run_moe_with_expert_parallelism( # Load checkpoint. load_model_and_optim_state(checkpoint_dir, moe) - # Run forward pass. + # Split batch across process group. total_tokens = batch.shape[0] * batch.shape[1] batch = batch.to(device=get_default_device()).requires_grad_(True) batch = distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),)) + + # Run forward pass. output = moe(batch) assert output.shape == batch.shape torch.testing.assert_close(output, expected_output.to(device=output.device)) @@ -113,6 +115,8 @@ def run_moe_with_expert_parallelism( def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: torch.dtype): seed_all(42) + device = torch.device("cuda") + d_model = 128 config = MoEConfig( name=moe_type, @@ -123,13 +127,13 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t dtype=DType.from_pt(dtype), ) moe = config.build(d_model=d_model, num_layers=1, init_device="cpu") - moe.to(device=get_default_device()) + moe.to(device=device) # Save checkpoint. save_model_and_optim_state(tmp_path, moe) B, S = 4, 16 - batch = torch.randn(B, S, d_model, dtype=dtype, device="cuda", requires_grad=True) + batch = torch.randn(B, S, d_model, dtype=dtype, device=device, requires_grad=True) output = moe(batch) assert output.shape == batch.shape From f077870a925b4550e696d550562392f5d9d66c56 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:46:10 -0800 Subject: [PATCH 068/230] fix --- src/test/nn/moe/moe_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 293623f7c..4b6f7a726 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -15,6 +15,7 @@ ExpertParallelConfig, build_expert_parallel_mesh, ) +from olmo_core.distributed.utils import get_local_tensor from olmo_core.nn.moe import MoEConfig, MoERouterConfig, MoEType from olmo_core.utils import get_default_device, seed_all @@ -92,7 +93,7 @@ def run_moe_with_expert_parallelism( batch = distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),)) # Run forward pass. 
- output = moe(batch) + output = moe(get_local_tensor(batch)) assert output.shape == batch.shape torch.testing.assert_close(output, expected_output.to(device=output.device)) From 3fda27432ee91253635d867bd22c852d077fd611 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:47:11 -0800 Subject: [PATCH 069/230] fix --- src/test/nn/moe/moe_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 4b6f7a726..9318c09a7 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -90,10 +90,10 @@ def run_moe_with_expert_parallelism( # Split batch across process group. total_tokens = batch.shape[0] * batch.shape[1] batch = batch.to(device=get_default_device()).requires_grad_(True) - batch = distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),)) + batch = get_local_tensor(distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),))) # Run forward pass. - output = moe(get_local_tensor(batch)) + output = moe(batch) assert output.shape == batch.shape torch.testing.assert_close(output, expected_output.to(device=output.device)) From 3c3544fa0a0748fa2bb4a6df5c401657c03793d4 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:49:07 -0800 Subject: [PATCH 070/230] fix --- src/test/nn/moe/moe_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 9318c09a7..a2e1bb465 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -87,15 +87,22 @@ def run_moe_with_expert_parallelism( # Load checkpoint. load_model_and_optim_state(checkpoint_dir, moe) - # Split batch across process group. + # Split batch and expected output across process group. total_tokens = batch.shape[0] * batch.shape[1] batch = batch.to(device=get_default_device()).requires_grad_(True) batch = get_local_tensor(distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),))) + expected_output = get_local_tensor( + distribute_tensor( + expected_output.to(device=get_default_device()), + device_mesh=ep_mesh, + placements=(Shard(0),), + ) + ) # Run forward pass. output = moe(batch) assert output.shape == batch.shape - torch.testing.assert_close(output, expected_output.to(device=output.device)) + torch.testing.assert_close(output, expected_output) losses = moe.compute_losses(total_tokens // ep_mesh.size()) lb_loss = losses["load balancing loss"] From 29a7fcc3a5580a371302dafef36b3736c6655844 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 12:51:35 -0800 Subject: [PATCH 071/230] fix --- src/test/nn/moe/moe_test.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index a2e1bb465..354b218c8 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -89,8 +89,12 @@ def run_moe_with_expert_parallelism( # Split batch and expected output across process group. 
total_tokens = batch.shape[0] * batch.shape[1] - batch = batch.to(device=get_default_device()).requires_grad_(True) - batch = get_local_tensor(distribute_tensor(batch, device_mesh=ep_mesh, placements=(Shard(0),))) + batch = get_local_tensor( + distribute_tensor( + batch.to(device=get_default_device()), device_mesh=ep_mesh, placements=(Shard(0),) + ) + ) + batch.requires_grad_(True) expected_output = get_local_tensor( distribute_tensor( expected_output.to(device=get_default_device()), @@ -121,6 +125,11 @@ def run_moe_with_expert_parallelism( @pytest.mark.parametrize("moe_type", [MoEType.dropless, MoEType.default]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: torch.dtype): + """ + Test that we get the same result when we run an MoE on a single device as we do when + we run it across multiple devices with expert parallelism. + """ + # This test proceeds as follows. seed_all(42) device = torch.device("cuda") @@ -137,7 +146,6 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t moe = config.build(d_model=d_model, num_layers=1, init_device="cpu") moe.to(device=device) - # Save checkpoint. save_model_and_optim_state(tmp_path, moe) B, S = 4, 16 From 54ae3acaa2c9ddd5b8703af2a86bf01eddd90d63 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 13:07:00 -0800 Subject: [PATCH 072/230] add more --- pyproject.toml | 1 + src/test/nn/moe/moe_test.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 40f45842e..654087b77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -169,4 +169,5 @@ filterwarnings = [ 'ignore::DeprecationWarning:pkg_resources', 'ignore::DeprecationWarning:google\.rpc', 'ignore::FutureWarning:torch\.distributed\.checkpoint\.default_planner', + 'ignore::UserWarning:torch\.distributed\.checkpoint\.state_dict_saver', ] diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 354b218c8..2378e703f 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -129,7 +129,6 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t Test that we get the same result when we run an MoE on a single device as we do when we run it across multiple devices with expert parallelism. """ - # This test proceeds as follows. seed_all(42) device = torch.device("cuda") @@ -146,13 +145,28 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t moe = config.build(d_model=d_model, num_layers=1, init_device="cpu") moe.to(device=device) + # Save state so when we spawn distributed processes they can load the same weights. save_model_and_optim_state(tmp_path, moe) + # Create batch and run forward pass. B, S = 4, 16 batch = torch.randn(B, S, d_model, dtype=dtype, device=device, requires_grad=True) output = moe(batch) assert output.shape == batch.shape + # Get losses. + losses = moe.compute_losses(B * S) + lb_loss = losses["load balancing loss"] + assert math.isfinite(lb_loss.item()) + + z_loss = losses["router Z loss"] + assert math.isfinite(z_loss.item()) + loss = lb_loss + z_loss + + # Run backward pass. 
+ loss.backward() + assert batch.grad is not None + run_distributed_test( run_moe_with_expert_parallelism, backend="nccl", From 4c3d0a9d1cda3954b9bb76f1a97035b6028198cf Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 13:26:24 -0800 Subject: [PATCH 073/230] add --- src/test/nn/moe/moe_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 2378e703f..f9418d150 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -56,6 +56,8 @@ def test_moe(moe_type, dtype): output = moe(x) assert output.shape == x.shape + assert torch.isfinite(output).all() + assert (output > 0).any() losses = moe.compute_losses(B * S) lb_loss = losses["load balancing loss"] @@ -153,6 +155,8 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t batch = torch.randn(B, S, d_model, dtype=dtype, device=device, requires_grad=True) output = moe(batch) assert output.shape == batch.shape + assert torch.isfinite(output).all() + assert (output > 0).any() # Get losses. losses = moe.compute_losses(B * S) From 9fb3d523debdde90a44a6bfed73b2b3c790ad7ff Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 13:30:55 -0800 Subject: [PATCH 074/230] init weights --- src/test/nn/moe/moe_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index f9418d150..fbd7a65f6 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -16,13 +16,18 @@ build_expert_parallel_mesh, ) from olmo_core.distributed.utils import get_local_tensor -from olmo_core.nn.moe import MoEConfig, MoERouterConfig, MoEType +from olmo_core.nn.moe import MoEBase, MoEConfig, MoERouterConfig, MoEType from olmo_core.utils import get_default_device, seed_all from ...distributed.utils import requires_multi_gpu, run_distributed_test from ...utils import requires_gpu +def init_mlp_weights(moe: MoEBase): + for w in (moe.experts.mlp.w1, moe.experts.mlp.w2, moe.experts.mlp.w2): + torch.nn.init.normal_(w, std=0.02) # type: ignore + + @requires_gpu @pytest.mark.parametrize("moe_type", [MoEType.dropless, MoEType.default]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) @@ -39,6 +44,7 @@ def test_moe(moe_type, dtype): dtype=DType.from_pt(dtype), ) moe = config.build(d_model=d_model, num_layers=1, init_device="cuda") + init_mlp_weights(moe) # Check num params calculation. num_params = 0 @@ -146,6 +152,7 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t ) moe = config.build(d_model=d_model, num_layers=1, init_device="cpu") moe.to(device=device) + init_mlp_weights(moe) # Save state so when we spawn distributed processes they can load the same weights. save_model_and_optim_state(tmp_path, moe) From dc5c13ffe1ec5cc14ce173b9e4913221867b5d6d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 13:40:41 -0800 Subject: [PATCH 075/230] add extra repr to MLP class --- src/olmo_core/nn/moe/mlp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 899994742..c51bd0921 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -114,6 +114,9 @@ def __init__( ), ) + def extra_repr(self): + return f"d_model={self.d_model}, num_experts={self.num_experts}, hidden_size={self.hidden_size}" + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Compute the expert outputs. 
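The `extra_repr` hook added in the patch above is what `nn.Module.__repr__` uses to annotate a module's one-line summary, so the MLP's dimensions show up whenever the model is printed. A minimal sketch of the effect with a toy module (the class and field names here are illustrative only, not the actual `MoEMLP`):

import torch
import torch.nn as nn

class ToyExpertMLP(nn.Module):
    # Toy stand-in that mirrors the extra_repr pattern from the patch above.
    def __init__(self, d_model: int, num_experts: int, hidden_size: int):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        # One 3-D weight with the (num_experts, hidden_size, d_model) layout.
        self.w1 = nn.Parameter(torch.empty(num_experts, hidden_size, d_model))

    def extra_repr(self) -> str:
        return f"d_model={self.d_model}, num_experts={self.num_experts}, hidden_size={self.hidden_size}"

print(ToyExpertMLP(d_model=128, num_experts=4, hidden_size=256))
# ToyExpertMLP(d_model=128, num_experts=4, hidden_size=256)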
From 97bf6182e091f6aa5e3aa992d6ad579ef1af641b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 13:51:55 -0800 Subject: [PATCH 076/230] debugging --- src/olmo_core/nn/moe/mlp.py | 5 +---- src/olmo_core/nn/moe/parallel_mlp.py | 6 ++++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index c51bd0921..b4c60f37e 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -132,10 +132,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # Compute the MLP. - x1 = torch.bmm(x, w1) - x2 = torch.bmm(x, w3) - x1 = F.silu(x1) * x2 - return torch.bmm(x1, w2) + return torch.bmm(F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3), w2) class DroplessMoEMLP(MoEMLPBase): diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 385ae54ec..b06f90562 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -171,6 +171,10 @@ def forward_once( with torch.no_grad(): indices, _, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) expert_capacity = self.expert_capacity(top_k, num_items) + print(f"{indices=}") + print(f"{bins=}") + print(f"{batch_size_per_expert=}") + print(f"{expert_capacity=}") x = self.permute_and_compute( x, @@ -326,10 +330,12 @@ def permute_and_compute( # Route the tokens for MoE computation. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + print(f"{x=}") # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) + print(f"{x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. # shape: (N, d_model) From 5073d502093cabb5c153d40831edec4e3a6f673c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 13:53:42 -0800 Subject: [PATCH 077/230] debug --- src/olmo_core/nn/moe/parallel_mlp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index b06f90562..77544233c 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -335,6 +335,9 @@ def permute_and_compute( # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) + print(f"{self.mlp.w1}") + print(f"{self.mlp.w2}") + print(f"{self.mlp.w3}") print(f"{x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. From 242649a845a8cefa3051247a3f8587407249caa2 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 13:54:33 -0800 Subject: [PATCH 078/230] fix --- src/olmo_core/nn/moe/parallel_mlp.py | 9 --------- src/test/nn/moe/moe_test.py | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 77544233c..385ae54ec 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -171,10 +171,6 @@ def forward_once( with torch.no_grad(): indices, _, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) expert_capacity = self.expert_capacity(top_k, num_items) - print(f"{indices=}") - print(f"{bins=}") - print(f"{batch_size_per_expert=}") - print(f"{expert_capacity=}") x = self.permute_and_compute( x, @@ -330,15 +326,10 @@ def permute_and_compute( # Route the tokens for MoE computation. 
# shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - print(f"{x=}") # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) - print(f"{self.mlp.w1}") - print(f"{self.mlp.w2}") - print(f"{self.mlp.w3}") - print(f"{x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. # shape: (N, d_model) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index fbd7a65f6..6165fb43b 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -24,7 +24,7 @@ def init_mlp_weights(moe: MoEBase): - for w in (moe.experts.mlp.w1, moe.experts.mlp.w2, moe.experts.mlp.w2): + for w in (moe.experts.mlp.w1, moe.experts.mlp.w2, moe.experts.mlp.w3): torch.nn.init.normal_(w, std=0.02) # type: ignore From 0c94a00ff4c28877c57cfbbf4a644545c224d411 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 14:40:05 -0800 Subject: [PATCH 079/230] clean up --- src/olmo_core/nn/moe/mlp.py | 8 ++++---- src/olmo_core/nn/moe/parallel_mlp.py | 20 +++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index b4c60f37e..e75613768 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -43,7 +43,7 @@ def __init__( self.num_experts = num_experts self.gradient_scale: Optional[float] = None - self.experts_per_rank = num_experts + self.num_local_experts = num_experts self.hidden_sharding_degree = 1 def scale_grad(self, w: torch.Tensor) -> torch.Tensor: @@ -63,7 +63,7 @@ def apply_ep(self, ep_mesh: DeviceMesh): f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel shard degree ({num_shards})." ) - self.experts_per_rank = self.num_experts // num_shards + self.num_local_experts = self.num_experts // num_shards self.gradient_scale = 1.0 / num_shards self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) # type: ignore @@ -124,7 +124,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :param x: The input of shape ``(num_local_experts, N, d_model)``. """ # Scale gradients and get local tensors (in case of expert parallelism). - # shape (all): (experts_per_rank, hidden_size, d_model) + # shape (all): (num_local_experts, hidden_size, d_model) w1, w2, w3 = ( get_local_tensor(self.scale_grad(self.w1)), get_local_tensor(self.scale_grad(self.w2)), @@ -214,7 +214,7 @@ def forward(self, x: torch.Tensor, batch_size_per_expert: torch.Tensor) -> torch 1-D ``LongTensor``. """ # Scale gradients and get local tensors (in case of expert parallelism). 
- # shape (all): (experts_per_rank, hidden_size, d_model) + # shape (all): (num_local_experts, hidden_size, d_model) w1, w2, w3 = ( get_local_tensor(self.scale_grad(self.w1)), get_local_tensor(self.scale_grad(self.w2)), diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 385ae54ec..2a5ab3eb3 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -39,8 +39,8 @@ def num_experts(self) -> int: return self.mlp.num_experts @property - def experts_per_rank(self) -> int: - return self.mlp.experts_per_rank + def num_local_experts(self) -> int: + return self.mlp.num_local_experts @property def hidden_sharding_degree(self) -> int: @@ -152,6 +152,8 @@ def __init__(self, *, mlp: MoEMLP, capacity_factor: float): self.capacity_factor = capacity_factor def expert_capacity(self, top_k: int, num_items: int) -> int: + # TODO: need to ensure this is the same across the process group, could be different w/ + # different batch sizes. items_per_expert = top_k * num_items * self.ep_world_size / self.num_experts return int(self.capacity_factor * items_per_expert) @@ -232,7 +234,7 @@ def parallel_forward_once( # TODO: Fuse this into the prior, local permutation? if self.hidden_sharding_degree > 1: # shape: (num_local_experts, ep_world_size // hidden_sharding_degree, expert_capacity, d_model) - x = x.view(self.experts_per_rank, -1, expert_capacity, self.d_model) + x = x.view(self.num_local_experts, -1, expert_capacity, self.d_model) # shape: (num_experts * hidden_sharding_degree, expert_capacity, d_model) x = x.repeat(1, self.hidden_sharding_degree, 1, 1).view( -1, expert_capacity, self.d_model @@ -264,7 +266,7 @@ def parallel_forward_once( dtype=torch.int32, device=indices.device, ), - self.experts_per_rank, + self.num_local_experts, ) # shape: (num_experts * expert_capacity,) @@ -280,7 +282,7 @@ def parallel_forward_once( # Calculate the bins boundaries from the token counts. # shape: (num_local_experts,) parallel_tokens_per_expert = move_to_device( - torch.tensor([expert_capacity] * self.experts_per_rank), parallel_indices.device + torch.tensor([expert_capacity] * self.num_local_experts), parallel_indices.device ) # shape: (num_local_experts,) parallel_bins = torch.empty_like(parallel_tokens_per_expert, dtype=torch.int32) @@ -425,12 +427,12 @@ def parallel_forward_once( with torch.no_grad(): tpe_handle.wait() - # Reshape to (ep_world_size, experts_per_rank). + # Reshape to (ep_world_size, num_local_experts). repeated_tokens_per_expert = repeated_tokens_per_expert.view( - self.ep_world_size, self.experts_per_rank + self.ep_world_size, self.num_local_experts ) parallel_tokens_per_expert = parallel_tokens_per_expert.view( - self.ep_world_size, self.experts_per_rank + self.ep_world_size, self.num_local_experts ) # NOTE: host-device sync here. 
@@ -470,7 +472,7 @@ def parallel_forward_once( dtype=torch.int32, device=indices.device, ), - self.experts_per_rank, + self.num_local_experts, ) parallel_top_expert = torch.repeat_interleave( From a6cb7ef87734712922a58cbabfd6ee7bc365be2f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:10:09 -0800 Subject: [PATCH 080/230] debug --- src/olmo_core/nn/moe/parallel_mlp.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 2a5ab3eb3..c0d46c309 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -226,6 +226,8 @@ def parallel_forward_once( # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + if dist.get_rank() == 0: + print(f"A {x=}") # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. @@ -291,6 +293,8 @@ def parallel_forward_once( # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. parallel_x_handle.wait() + if dist.get_rank() == 0: + print(f"B {parallel_x=}") parallel_x = self.permute_and_compute( parallel_x, indices=parallel_indices.int(), @@ -299,16 +303,22 @@ def parallel_forward_once( expert_capacity=expert_capacity, top_k=1, ) + if dist.get_rank() == 0: + print(f"C {parallel_x=}") # Un-permute the tokens across the devices. x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) + if dist.get_rank() == 0: + print(f"D {x=}") # Reduce along the hidden sharding to get the final outputs. - # TODO: Fuse this into the following local permutation? - x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) + if self.hidden_sharding_degree > 1: + x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) # Un-permute locally to setup for the next series of operations. x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + if dist.get_rank() == 0: + print(f"E {x=}") return x, tokens_per_expert.flatten() From 66f4afd0026039a44b5438a5b8ef017f3e5b3c71 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:17:58 -0800 Subject: [PATCH 081/230] debugging --- src/olmo_core/nn/moe/parallel_mlp.py | 10 ---------- src/test/nn/moe/moe_test.py | 4 ++++ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index c0d46c309..1bc475fd4 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -226,8 +226,6 @@ def parallel_forward_once( # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - if dist.get_rank() == 0: - print(f"A {x=}") # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. @@ -293,8 +291,6 @@ def parallel_forward_once( # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. 
parallel_x_handle.wait() - if dist.get_rank() == 0: - print(f"B {parallel_x=}") parallel_x = self.permute_and_compute( parallel_x, indices=parallel_indices.int(), @@ -303,13 +299,9 @@ def parallel_forward_once( expert_capacity=expert_capacity, top_k=1, ) - if dist.get_rank() == 0: - print(f"C {parallel_x=}") # Un-permute the tokens across the devices. x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) - if dist.get_rank() == 0: - print(f"D {x=}") # Reduce along the hidden sharding to get the final outputs. if self.hidden_sharding_degree > 1: @@ -317,8 +309,6 @@ def parallel_forward_once( # Un-permute locally to setup for the next series of operations. x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - if dist.get_rank() == 0: - print(f"E {x=}") return x, tokens_per_expert.flatten() diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 6165fb43b..09e96deb5 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -115,6 +115,9 @@ def run_moe_with_expert_parallelism( output = moe(batch) assert output.shape == batch.shape torch.testing.assert_close(output, expected_output) + if dist.get_rank() == 0: + print(f"{output=}") + print(f"{expected_output=}") losses = moe.compute_losses(total_tokens // ep_mesh.size()) lb_loss = losses["load balancing loss"] @@ -164,6 +167,7 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t assert output.shape == batch.shape assert torch.isfinite(output).all() assert (output > 0).any() + print(f"before dist, expected_output={output}") # Get losses. losses = moe.compute_losses(B * S) From 9582c3cd4eb190e2090e53356b82a46e17ad500e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:27:50 -0800 Subject: [PATCH 082/230] more debug --- src/olmo_core/nn/moe/parallel_mlp.py | 7 ++++++- src/test/nn/moe/moe_test.py | 12 ++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 1bc475fd4..f0d0a2a5a 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -324,18 +324,23 @@ def permute_and_compute( ) -> torch.Tensor: # shape: (N, d_model) x = x.view(-1, x.shape[-1]) + print(f"A, {x=}") # Route the tokens for MoE computation. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + print(f"B, {x=}") # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) + print(f"C, {x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. 
# shape: (N, d_model) - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + x = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + print(f"D, {x=}") + return x class ParallelDroplessMLP(ParallelMLPBase): diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 09e96deb5..c87871962 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -182,9 +182,9 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t loss.backward() assert batch.grad is not None - run_distributed_test( - run_moe_with_expert_parallelism, - backend="nccl", - start_method="spawn", - func_args=(tmp_path, config, d_model, batch.detach().cpu(), output.detach().cpu()), - ) + # run_distributed_test( + # run_moe_with_expert_parallelism, + # backend="nccl", + # start_method="spawn", + # func_args=(tmp_path, config, d_model, batch.detach().cpu(), output.detach().cpu()), + # ) From 461072ad2be1f62895dbc4e87eaa5bcfd953dbf0 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:28:31 -0800 Subject: [PATCH 083/230] lol --- src/test/nn/moe/moe_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index c87871962..b9394e70f 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -182,6 +182,7 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t loss.backward() assert batch.grad is not None + assert False # run_distributed_test( # run_moe_with_expert_parallelism, # backend="nccl", From d44000cc0f3da70841a7b612f874e573935ca750 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:39:12 -0800 Subject: [PATCH 084/230] expert indices --- src/olmo_core/nn/moe/parallel_mlp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index f0d0a2a5a..fd1901315 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -168,6 +168,7 @@ def forward_once( # shape: (N * top_k,) expert_weights = expert_weights.flatten() # shape: (N * top_k,) + print(f"{expert_indices=}") expert_indices = expert_indices.flatten() with torch.no_grad(): From acf03bea933e3878c5596a698ad28483dcbcf2a2 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:43:07 -0800 Subject: [PATCH 085/230] try w/ uniform assignment --- src/olmo_core/nn/moe/parallel_mlp.py | 8 +------- src/test/nn/moe/moe_test.py | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index fd1901315..1bc475fd4 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -168,7 +168,6 @@ def forward_once( # shape: (N * top_k,) expert_weights = expert_weights.flatten() # shape: (N * top_k,) - print(f"{expert_indices=}") expert_indices = expert_indices.flatten() with torch.no_grad(): @@ -325,23 +324,18 @@ def permute_and_compute( ) -> torch.Tensor: # shape: (N, d_model) x = x.view(-1, x.shape[-1]) - print(f"A, {x=}") # Route the tokens for MoE computation. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - print(f"B, {x=}") # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) - print(f"C, {x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. 
# shape: (N, d_model) - x = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - print(f"D, {x=}") - return x + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) class ParallelDroplessMLP(ParallelMLPBase): diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index b9394e70f..b6e178044 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -149,7 +149,11 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t name=moe_type, num_experts=4, hidden_size=256, - router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), + router=MoERouterConfig( + top_k=1, + uniform_expert_assignment=moe_type == MoEType.default, + dtype=DType.from_pt(dtype), + ), z_loss_weight=0.1, dtype=DType.from_pt(dtype), ) @@ -167,7 +171,6 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t assert output.shape == batch.shape assert torch.isfinite(output).all() assert (output > 0).any() - print(f"before dist, expected_output={output}") # Get losses. losses = moe.compute_losses(B * S) @@ -182,10 +185,9 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t loss.backward() assert batch.grad is not None - assert False - # run_distributed_test( - # run_moe_with_expert_parallelism, - # backend="nccl", - # start_method="spawn", - # func_args=(tmp_path, config, d_model, batch.detach().cpu(), output.detach().cpu()), - # ) + run_distributed_test( + run_moe_with_expert_parallelism, + backend="nccl", + start_method="spawn", + func_args=(tmp_path, config, d_model, batch.detach().cpu(), output.detach().cpu()), + ) From 1ee38fab8a6ead306af59243caa0b25e1a66ec38 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:44:01 -0800 Subject: [PATCH 086/230] fix --- src/test/nn/moe/moe_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index b6e178044..536d72064 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -114,10 +114,10 @@ def run_moe_with_expert_parallelism( # Run forward pass. output = moe(batch) assert output.shape == batch.shape - torch.testing.assert_close(output, expected_output) if dist.get_rank() == 0: print(f"{output=}") print(f"{expected_output=}") + torch.testing.assert_close(output, expected_output) losses = moe.compute_losses(total_tokens // ep_mesh.size()) lb_loss = losses["load balancing loss"] From fd945a5807fdd3be637f93ac9c7e511aa51f9472 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:49:14 -0800 Subject: [PATCH 087/230] debug --- src/olmo_core/nn/moe/parallel_mlp.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 1bc475fd4..311007807 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -211,6 +211,8 @@ def parallel_forward_once( # output. # shape: (N, d_model) x = x.view(-1, x.shape[-1]) + if dist.get_rank() == 0: + print(f"A {x=}") num_items, top_k = expert_weights.shape @@ -226,6 +228,8 @@ def parallel_forward_once( # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + if dist.get_rank() == 0: + print(f"B {x=}") # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. 
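For orientation while reading these debugging commits: `permute_and_compute` gathers each expert's assigned tokens into a fixed-capacity buffer, runs the expert MLP on the gathered tokens, then scatters the results back to their original positions, zeroing anything that exceeded capacity. A rough pure-PyTorch sketch of that idea for a single device with top-1 routing (the helper below is illustrative only, not the megablocks-style `ops` API used in the diffs; expert weighting is omitted):

import torch

def toy_permute_and_compute(x, expert_indices, experts, num_experts, expert_capacity):
    # x: (N, d_model); expert_indices: (N,) longs in [0, num_experts).
    # experts: callable mapping (num_experts, expert_capacity, d_model) -> same shape.
    N, d_model = x.shape
    out = torch.zeros_like(x)
    binned = x.new_zeros(num_experts, expert_capacity, d_model)
    placements = []          # (expert, slot, token) triples, used to un-permute
    fill = [0] * num_experts
    for token, e in enumerate(expert_indices.tolist()):
        if fill[e] < expert_capacity:  # tokens over capacity are dropped (left as zeros in `out`)
            binned[e, fill[e]] = x[token]
            placements.append((e, fill[e], token))
            fill[e] += 1
    y = experts(binned)              # per-expert MLP on the gathered tokens
    for e, slot, token in placements:
        out[token] = y[e, slot]
    return out

The real implementation does the same thing with batched gather/scatter kernels, and under expert parallelism inserts an all_to_all so that each rank only runs its local experts; the fixed capacity is what keeps those buffer shapes static across ranks.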
@@ -291,6 +295,8 @@ def parallel_forward_once( # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. parallel_x_handle.wait() + if dist.get_rank() == 0: + print(f"C {parallel_x=}") parallel_x = self.permute_and_compute( parallel_x, indices=parallel_indices.int(), @@ -302,6 +308,8 @@ def parallel_forward_once( # Un-permute the tokens across the devices. x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) + if dist.get_rank() == 0: + print(f"D {x=}") # Reduce along the hidden sharding to get the final outputs. if self.hidden_sharding_degree > 1: @@ -309,6 +317,8 @@ def parallel_forward_once( # Un-permute locally to setup for the next series of operations. x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + if dist.get_rank() == 0: + print(f"E {x=}") return x, tokens_per_expert.flatten() From e13fedc92169a9bf549312b2a0c55541e5b52126 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 15:52:01 -0800 Subject: [PATCH 088/230] small bz --- src/test/nn/moe/moe_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 536d72064..4fe930baa 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -144,7 +144,7 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t device = torch.device("cuda") - d_model = 128 + d_model = 8 config = MoEConfig( name=moe_type, num_experts=4, @@ -165,7 +165,7 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t save_model_and_optim_state(tmp_path, moe) # Create batch and run forward pass. - B, S = 4, 16 + B, S = 2, 4 batch = torch.randn(B, S, d_model, dtype=dtype, device=device, requires_grad=True) output = moe(batch) assert output.shape == batch.shape From 87afd270ba641f0008eee57cb5df22ab64112920 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 16:01:37 -0800 Subject: [PATCH 089/230] more --- src/olmo_core/nn/moe/parallel_mlp.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 311007807..637c14fc1 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -224,6 +224,7 @@ def parallel_forward_once( with torch.no_grad(): indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) expert_capacity = self.expert_capacity(top_k, num_items) + print(f"{expert_capacity=}") # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, expert_capacity, d_model) @@ -304,12 +305,15 @@ def parallel_forward_once( bins=parallel_bins, expert_capacity=expert_capacity, top_k=1, + debug=dist.get_rank() == 0, ) + if dist.get_rank() == 0: + print(f"D {parallel_x=}") # Un-permute the tokens across the devices. x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) if dist.get_rank() == 0: - print(f"D {x=}") + print(f"E {x=}") # Reduce along the hidden sharding to get the final outputs. if self.hidden_sharding_degree > 1: @@ -331,21 +335,31 @@ def permute_and_compute( bins: torch.Tensor, expert_capacity: int, top_k: int, + debug: bool = False, ) -> torch.Tensor: # shape: (N, d_model) x = x.view(-1, x.shape[-1]) + if debug: + print(f"C.1 {x=}") # Route the tokens for MoE computation. 
# shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + if debug: + print(f"C.2 {x=}") # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) + if debug: + print(f"C.3 {x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. # shape: (N, d_model) - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + x = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + if debug: + print(f"C.4 {x=}") + return x class ParallelDroplessMLP(ParallelMLPBase): From 2ceba527d75054de65498b9d26b1862e4ae1f484 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 16:05:23 -0800 Subject: [PATCH 090/230] more --- src/olmo_core/nn/moe/parallel_mlp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 637c14fc1..3ab84fbaa 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -212,7 +212,7 @@ def parallel_forward_once( # shape: (N, d_model) x = x.view(-1, x.shape[-1]) if dist.get_rank() == 0: - print(f"A {x=}") + print(f"A {x.shape=} {x=}") num_items, top_k = expert_weights.shape @@ -230,7 +230,7 @@ def parallel_forward_once( # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) if dist.get_rank() == 0: - print(f"B {x=}") + print(f"B {x.shape=} {x=}") # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. @@ -297,7 +297,7 @@ def parallel_forward_once( # Block to make sure that the cross-device permutation is complete. parallel_x_handle.wait() if dist.get_rank() == 0: - print(f"C {parallel_x=}") + print(f"C {parallel_x.shape=} {parallel_x=}") parallel_x = self.permute_and_compute( parallel_x, indices=parallel_indices.int(), @@ -313,7 +313,7 @@ def parallel_forward_once( # Un-permute the tokens across the devices. x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) if dist.get_rank() == 0: - print(f"E {x=}") + print(f"E {x.shape=} {x=}") # Reduce along the hidden sharding to get the final outputs. if self.hidden_sharding_degree > 1: @@ -322,7 +322,7 @@ def parallel_forward_once( # Un-permute locally to setup for the next series of operations. x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) if dist.get_rank() == 0: - print(f"E {x=}") + print(f"E {x.shape=} {x=}") return x, tokens_per_expert.flatten() From 7db701824af89ea9cdf87ac9a71999014a2d4ffe Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 16:08:58 -0800 Subject: [PATCH 091/230] more debug --- src/olmo_core/nn/moe/parallel_mlp.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 3ab84fbaa..6d0bf6944 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -224,7 +224,8 @@ def parallel_forward_once( with torch.no_grad(): indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) expert_capacity = self.expert_capacity(top_k, num_items) - print(f"{expert_capacity=}") + if dist.get_rank() == 0: + print(f"{expert_capacity=}") # Permute locally so that the tokens for each device are stored contiguously. 
# shape: (num_experts, expert_capacity, d_model) @@ -308,7 +309,7 @@ def parallel_forward_once( debug=dist.get_rank() == 0, ) if dist.get_rank() == 0: - print(f"D {parallel_x=}") + print(f"D {parallel_x.shape=} {parallel_x=}") # Un-permute the tokens across the devices. x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) @@ -340,25 +341,25 @@ def permute_and_compute( # shape: (N, d_model) x = x.view(-1, x.shape[-1]) if debug: - print(f"C.1 {x=}") + print(f"C.1 {x.shape=} {x=}") # Route the tokens for MoE computation. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) if debug: - print(f"C.2 {x=}") + print(f"C.2 {x.shape=} {x=}") # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) if debug: - print(f"C.3 {x=}") + print(f"C.3 {x.shape=} {x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. # shape: (N, d_model) x = ops.binned_scatter(x, indices, expert_weights, bins, top_k) if debug: - print(f"C.4 {x=}") + print(f"C.4 {x.shape=} {x=}") return x From 157b38372521bc255b7b4345455ec9d4a5ce4371 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 16:22:31 -0800 Subject: [PATCH 092/230] try this --- src/olmo_core/nn/moe/parallel_mlp.py | 41 ++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 6d0bf6944..fc56601f0 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -153,9 +153,11 @@ def __init__(self, *, mlp: MoEMLP, capacity_factor: float): def expert_capacity(self, top_k: int, num_items: int) -> int: # TODO: need to ensure this is the same across the process group, could be different w/ - # different batch sizes. - items_per_expert = top_k * num_items * self.ep_world_size / self.num_experts - return int(self.capacity_factor * items_per_expert) + # different local batch sizes. + num_global_items = num_items * self.ep_world_size + num_global_expert_inputs = top_k * num_global_items + inputs_per_expert = num_global_expert_inputs / self.num_experts + return int(self.capacity_factor * inputs_per_expert) def forward_once( self, @@ -224,12 +226,13 @@ def parallel_forward_once( with torch.no_grad(): indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) expert_capacity = self.expert_capacity(top_k, num_items) + local_expert_capacity = expert_capacity // self.ep_world_size if dist.get_rank() == 0: - print(f"{expert_capacity=}") + print(f"{expert_capacity=}, {local_expert_capacity=}") # Permute locally so that the tokens for each device are stored contiguously. - # shape: (num_experts, expert_capacity, d_model) - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # shape: (num_experts, local_expert_capacity, d_model) + x = ops.binned_gather(x, indices, bins, local_expert_capacity, top_k) if dist.get_rank() == 0: print(f"B {x.shape=} {x=}") @@ -237,19 +240,18 @@ def parallel_forward_once( # multiple devices own parts of the same sets of experts. # Replicate the token counts so devices that share experts # get all of the tokens assigned to them. - # TODO: Fuse this into the prior, local permutation? 
if self.hidden_sharding_degree > 1: - # shape: (num_local_experts, ep_world_size // hidden_sharding_degree, expert_capacity, d_model) - x = x.view(self.num_local_experts, -1, expert_capacity, self.d_model) - # shape: (num_experts * hidden_sharding_degree, expert_capacity, d_model) + # shape: (num_local_experts, ep_world_size // hidden_sharding_degree, local_expert_capacity, d_model) + x = x.view(self.num_local_experts, -1, local_expert_capacity, self.d_model) + # shape: (num_experts * hidden_sharding_degree, local_expert_capacity, d_model) x = x.repeat(1, self.hidden_sharding_degree, 1, 1).view( - -1, expert_capacity, self.d_model + -1, local_expert_capacity, self.d_model ) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. - # shape: (num_local_experts * ep_world_size, expert_capacity, d_model) - # = (num_experts, expert_capacity, d_model) + # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) + # = (num_experts, local_expert_capacity, d_model) parallel_x, parallel_x_handle = ops.all_to_all( x, group=self._ep_pg, @@ -275,20 +277,21 @@ def parallel_forward_once( self.num_local_experts, ) - # shape: (num_experts * expert_capacity,) + # shape: (num_experts * local_expert_capacity,) parallel_top_expert = torch.repeat_interleave( parallel_top_expert, - expert_capacity, - output_size=parallel_top_expert.numel() * expert_capacity, + local_expert_capacity, + output_size=parallel_top_expert.numel() * local_expert_capacity, ) - # shape: (num_experts * expert_capacity,) + # shape: (num_experts * local_expert_capacity,) _, parallel_indices = torch.sort(parallel_top_expert) # Calculate the bins boundaries from the token counts. # shape: (num_local_experts,) parallel_tokens_per_expert = move_to_device( - torch.tensor([expert_capacity] * self.num_local_experts), parallel_indices.device + torch.tensor([local_expert_capacity] * self.num_local_experts), + parallel_indices.device, ) # shape: (num_local_experts,) parallel_bins = torch.empty_like(parallel_tokens_per_expert, dtype=torch.int32) @@ -469,7 +472,6 @@ def parallel_forward_once( # multiple devices own parts of the same sets of experts. # Replicate the token counts so devices that share experts # get all of the tokens assigned to them. - # TODO: Fuse this into the prior, local permutation? x = ops.repeat(x, (self.hidden_sharding_degree, 1)) # Start the cross-device permutation asynchronously so we can @@ -538,7 +540,6 @@ def parallel_forward_once( ) # Reduce along the hidden sharding to get the final outputs. - # TODO: Fuse this into the following local permutation? x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) # Un-permute locally to setup for the next series of operations. From b1116597ed4a625d9b24703845822ede9ffd98a2 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 16:26:28 -0800 Subject: [PATCH 093/230] try this --- src/olmo_core/nn/moe/parallel_mlp.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index fc56601f0..543c80191 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -250,8 +250,8 @@ def parallel_forward_once( # Start the cross-device permutation asynchronously so we can # overlap communication with computation. 
- # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) - # = (num_experts, local_expert_capacity, d_model) + # shape: (num_local_experts * ep_world_size, expert_capacity, d_model) + # = (num_experts, expert_capacity, d_model) parallel_x, parallel_x_handle = ops.all_to_all( x, group=self._ep_pg, @@ -277,20 +277,20 @@ def parallel_forward_once( self.num_local_experts, ) - # shape: (num_experts * local_expert_capacity,) + # shape: (num_experts * expert_capacity,) parallel_top_expert = torch.repeat_interleave( parallel_top_expert, - local_expert_capacity, - output_size=parallel_top_expert.numel() * local_expert_capacity, + expert_capacity, + output_size=parallel_top_expert.numel() * expert_capacity, ) - # shape: (num_experts * local_expert_capacity,) + # shape: (num_experts * expert_capacity,) _, parallel_indices = torch.sort(parallel_top_expert) # Calculate the bins boundaries from the token counts. # shape: (num_local_experts,) parallel_tokens_per_expert = move_to_device( - torch.tensor([local_expert_capacity] * self.num_local_experts), + torch.tensor([expert_capacity] * self.num_local_experts), parallel_indices.device, ) # shape: (num_local_experts,) From 0cc1c7891b7d5d814ea0133a9bc19bd2ca46a210 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 16:45:19 -0800 Subject: [PATCH 094/230] try this --- src/olmo_core/nn/moe/parallel_mlp.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 543c80191..283491f9f 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -250,8 +250,7 @@ def parallel_forward_once( # Start the cross-device permutation asynchronously so we can # overlap communication with computation. - # shape: (num_local_experts * ep_world_size, expert_capacity, d_model) - # = (num_experts, expert_capacity, d_model) + # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) parallel_x, parallel_x_handle = ops.all_to_all( x, group=self._ep_pg, @@ -277,14 +276,15 @@ def parallel_forward_once( self.num_local_experts, ) - # shape: (num_experts * expert_capacity,) + # shape: (num_local_experts * ep_world_size * local_expert_capacity,) + # = (num_local_experts * expert_capacity,) parallel_top_expert = torch.repeat_interleave( parallel_top_expert, - expert_capacity, - output_size=parallel_top_expert.numel() * expert_capacity, + local_expert_capacity, + output_size=parallel_top_expert.numel() * local_expert_capacity, ) - # shape: (num_experts * expert_capacity,) + # shape: (num_local_experts * expert_capacity,) _, parallel_indices = torch.sort(parallel_top_expert) # Calculate the bins boundaries from the token counts. 
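For intuition about the capacity bookkeeping these commits are iterating on: `expert_capacity` is sized against the global batch, and each rank's local binned_gather uses `expert_capacity // ep_world_size` slots per expert (the `local_expert_capacity`). A small worked example of that arithmetic, restated as a standalone function under assumed numbers (2 ranks, 64 tokens per rank, top-1 routing, 4 experts, capacity factor 1.2):

def expert_capacity(top_k, num_items, ep_world_size, num_experts, capacity_factor):
    # Standalone restatement of the arithmetic in ParallelMLP.expert_capacity above.
    num_global_items = num_items * ep_world_size
    inputs_per_expert = top_k * num_global_items / num_experts
    return int(capacity_factor * inputs_per_expert)

cap = expert_capacity(top_k=1, num_items=64, ep_world_size=2, num_experts=4, capacity_factor=1.2)
# cap == int(1.2 * (1 * 64 * 2) / 4) == 38 slots per expert globally;
# each rank contributes cap // 2 == 19 slots per expert to the all_to_all exchange.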
@@ -302,10 +302,12 @@ def parallel_forward_once( parallel_x_handle.wait() if dist.get_rank() == 0: print(f"C {parallel_x.shape=} {parallel_x=}") + # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) + # ~= (num_local_experts, expert_capacity, d_model) parallel_x = self.permute_and_compute( parallel_x, indices=parallel_indices.int(), - expert_weights=None, # expert_weights + expert_weights=None, bins=parallel_bins, expert_capacity=expert_capacity, top_k=1, From fea70b7c426bf2945f171f595538c7924d4d78b3 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 17:13:23 -0800 Subject: [PATCH 095/230] cache --- src/olmo_core/nn/buffer_cache.py | 12 +++ src/olmo_core/nn/moe/moe.py | 17 ++- src/olmo_core/nn/moe/parallel_mlp.py | 147 +++++++++++++------------- src/olmo_core/nn/transformer/block.py | 2 +- src/test/nn/moe/moe_test.py | 6 +- 5 files changed, 105 insertions(+), 79 deletions(-) diff --git a/src/olmo_core/nn/buffer_cache.py b/src/olmo_core/nn/buffer_cache.py index 918e32253..8d3402747 100644 --- a/src/olmo_core/nn/buffer_cache.py +++ b/src/olmo_core/nn/buffer_cache.py @@ -1,7 +1,10 @@ from collections.abc import MutableMapping +from typing import Optional import torch +from olmo_core.utils import move_to_device + class BufferCache(dict, MutableMapping[str, torch.Tensor]): """ @@ -12,3 +15,12 @@ class BufferCache(dict, MutableMapping[str, torch.Tensor]): since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into NaNs when they're synchronized due to casting or some other issue. """ + + def get_for_device(self, key: str, device: torch.device) -> Optional[torch.Tensor]: + if (tensor := self.get(key)) is not None: + if tensor.device != device: + tensor = move_to_device(tensor, device) + self[key] = tensor + return tensor + else: + return None diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index f4704ea85..6dc2cbe84 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -12,6 +12,7 @@ from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel from olmo_core.exceptions import OLMoConfigurationError +from ..buffer_cache import BufferCache from .loss import MoELoadBalancingLoss, MoELoss, MoERouterZLoss from .mlp import DroplessMoEMLP, MoEMLP from .parallel_mlp import ParallelDroplessMLP, ParallelMLP, ParallelMLPBase @@ -67,7 +68,14 @@ def num_active_params(self, d_model: int) -> int: + (3 * d_model * self.hidden_size * self.router.top_k) ) - def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> "MoEBase": + def build( + self, + d_model: int, + *, + num_layers: int, + init_device: str = "cpu", + cache: Optional[BufferCache] = None, + ) -> "MoEBase": kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") kwargs.update( @@ -75,6 +83,7 @@ def build(self, d_model: int, *, num_layers: int, init_device: str = "cpu") -> " num_layers=num_layers, init_device=init_device, dtype=kwargs.pop("dtype").as_pt(), + cache=cache, ) try: @@ -108,6 +117,7 @@ def __init__( lb_loss_weight: Optional[float] = None, z_loss_weight: Optional[float] = None, dtype: torch.dtype = torch.float32, + cache: Optional[BufferCache] = None, **kwargs, ): super().__init__() @@ -118,6 +128,7 @@ def __init__( hidden_size=hidden_size, dtype=dtype, init_device=init_device, + cache=cache, **kwargs, ) self.shared_experts = ( @@ -261,6 +272,7 @@ def _init_parallel_mlp( # type: ignore[override] capacity_factor: float, dtype: torch.dtype = torch.float32, 
init_device: str = "cpu", + cache: Optional[BufferCache] = None, ) -> ParallelMLP: return ParallelMLP( mlp=MoEMLP( @@ -271,6 +283,7 @@ def _init_parallel_mlp( # type: ignore[override] init_device=init_device, ), capacity_factor=capacity_factor, + cache=cache, ) @@ -287,6 +300,7 @@ def _init_parallel_mlp( # type: ignore[override] hidden_size: int, dtype: torch.dtype = torch.float32, init_device: str = "cpu", + cache: Optional[BufferCache] = None, ) -> ParallelDroplessMLP: return ParallelDroplessMLP( mlp=DroplessMoEMLP( @@ -296,4 +310,5 @@ def _init_parallel_mlp( # type: ignore[override] dtype=dtype, init_device=init_device, ), + cache=cache, ) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 283491f9f..e0b1590d2 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -12,6 +12,7 @@ from olmo_core.distributed.utils import get_world_size from olmo_core.utils import move_to_device +from ..buffer_cache import BufferCache from . import ops from .mlp import DroplessMoEMLP, MoEMLP, MoEMLPBase @@ -23,9 +24,10 @@ class ParallelMLPBase(nn.Module): Wraps an MoE MLP layer to coordinate the routing and expert parallelism. """ - def __init__(self, *, mlp: MoEMLPBase): + def __init__(self, *, mlp: MoEMLPBase, cache: Optional[BufferCache] = None): super().__init__() self.mlp = mlp + self._cache = cache or BufferCache() self._expert_parallel_enabled: bool = False self._ep_mesh: Optional[DeviceMesh] = None self._ep_pg: Optional[dist.ProcessGroup] = None @@ -147,18 +149,70 @@ def parallel_forward_once( class ParallelMLP(ParallelMLPBase): - def __init__(self, *, mlp: MoEMLP, capacity_factor: float): - super().__init__(mlp=mlp) + def __init__(self, *, mlp: MoEMLP, capacity_factor: float, cache: Optional[BufferCache] = None): + super().__init__(mlp=mlp, cache=cache) self.capacity_factor = capacity_factor def expert_capacity(self, top_k: int, num_items: int) -> int: - # TODO: need to ensure this is the same across the process group, could be different w/ - # different local batch sizes. + # TODO: need to ensure this is the same across the process group. + # If local batch sizes are different then these will be different, and `parallel_forward_once` + # will break. This shouldn't be a problem with our trainer, but would be an issue for inference. num_global_items = num_items * self.ep_world_size num_global_expert_inputs = top_k * num_global_items inputs_per_expert = num_global_expert_inputs / self.num_experts return int(self.capacity_factor * inputs_per_expert) + @torch.no_grad() + def _get_parallel_indices_and_bins( + self, *, expert_capacity: int, local_expert_capacity: int, device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor]: + indices_cache_key = f"moe_par_expert_indices_{expert_capacity}_{local_expert_capacity}" + bins_cache_key = f"moe_par_expert_bins_{expert_capacity}_{local_expert_capacity}" + + if ( + parallel_indices := self._cache.get_for_device(indices_cache_key, device) + ) is not None and ( + parallel_bins := self._cache.get_for_device(bins_cache_key, device) + ) is not None: + return parallel_indices, parallel_bins + + # Construct the expert indices for the permuted tokens. 
+ # shape: (num_experts,) = (num_local_experts * ep_world_size,) + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * self.hidden_sharding_degree, + dtype=torch.int32, + device=device, + ), + self.num_local_experts, + ) + + # shape: (num_local_experts * ep_world_size * local_expert_capacity,) + # = (num_local_experts * expert_capacity,) + parallel_top_expert = torch.repeat_interleave( + parallel_top_expert, + local_expert_capacity, + output_size=parallel_top_expert.numel() * local_expert_capacity, + ) + + # shape: (num_local_experts * expert_capacity,) + _, parallel_indices = torch.sort(parallel_top_expert) + + # Calculate the bins boundaries from the token counts. + # shape: (num_local_experts,) + parallel_tokens_per_expert = move_to_device( + torch.tensor([expert_capacity] * self.num_local_experts), + parallel_indices.device, + ) + # shape: (num_local_experts,) + parallel_bins = torch.empty_like(parallel_tokens_per_expert, dtype=torch.int32) + torch.cumsum(parallel_tokens_per_expert, 0, out=parallel_bins) + + self._cache[indices_cache_key] = parallel_indices + self._cache[bins_cache_key] = parallel_bins + + return parallel_indices, parallel_bins + def forward_once( self, x: torch.Tensor, @@ -166,6 +220,7 @@ def forward_once( expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: num_items, top_k = expert_weights.shape + expert_capacity = self.expert_capacity(top_k, num_items) # shape: (N * top_k,) expert_weights = expert_weights.flatten() @@ -174,7 +229,6 @@ def forward_once( with torch.no_grad(): indices, _, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) - expert_capacity = self.expert_capacity(top_k, num_items) x = self.permute_and_compute( x, @@ -213,10 +267,10 @@ def parallel_forward_once( # output. # shape: (N, d_model) x = x.view(-1, x.shape[-1]) - if dist.get_rank() == 0: - print(f"A {x.shape=} {x=}") num_items, top_k = expert_weights.shape + expert_capacity = self.expert_capacity(top_k, num_items) + local_expert_capacity = expert_capacity // self.ep_world_size # shape: (N * top_k,) expert_weights = expert_weights.flatten() @@ -225,16 +279,10 @@ def parallel_forward_once( with torch.no_grad(): indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) - expert_capacity = self.expert_capacity(top_k, num_items) - local_expert_capacity = expert_capacity // self.ep_world_size - if dist.get_rank() == 0: - print(f"{expert_capacity=}, {local_expert_capacity=}") # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, local_expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, local_expert_capacity, top_k) - if dist.get_rank() == 0: - print(f"B {x.shape=} {x=}") # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. @@ -257,51 +305,20 @@ def parallel_forward_once( async_op=True, ) - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - - # Construct the expert indices for the permuted tokens. 
- # shape: (num_experts,) = (num_local_experts * ep_world_size,) - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * self.hidden_sharding_degree, - dtype=torch.int32, - device=indices.device, - ), - self.num_local_experts, - ) - - # shape: (num_local_experts * ep_world_size * local_expert_capacity,) - # = (num_local_experts * expert_capacity,) - parallel_top_expert = torch.repeat_interleave( - parallel_top_expert, - local_expert_capacity, - output_size=parallel_top_expert.numel() * local_expert_capacity, - ) - - # shape: (num_local_experts * expert_capacity,) - _, parallel_indices = torch.sort(parallel_top_expert) - - # Calculate the bins boundaries from the token counts. - # shape: (num_local_experts,) - parallel_tokens_per_expert = move_to_device( - torch.tensor([expert_capacity] * self.num_local_experts), - parallel_indices.device, - ) - # shape: (num_local_experts,) - parallel_bins = torch.empty_like(parallel_tokens_per_expert, dtype=torch.int32) - torch.cumsum(parallel_tokens_per_expert, 0, out=parallel_bins) + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. + # shape (both): (num_local_experts,) + parallel_indices, parallel_bins = self._get_parallel_indices_and_bins( + expert_capacity=expert_capacity, + local_expert_capacity=local_expert_capacity, + device=indices.device, + ) # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. parallel_x_handle.wait() - if dist.get_rank() == 0: - print(f"C {parallel_x.shape=} {parallel_x=}") # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) # ~= (num_local_experts, expert_capacity, d_model) parallel_x = self.permute_and_compute( @@ -311,15 +328,10 @@ def parallel_forward_once( bins=parallel_bins, expert_capacity=expert_capacity, top_k=1, - debug=dist.get_rank() == 0, ) - if dist.get_rank() == 0: - print(f"D {parallel_x.shape=} {parallel_x=}") # Un-permute the tokens across the devices. x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) - if dist.get_rank() == 0: - print(f"E {x.shape=} {x=}") # Reduce along the hidden sharding to get the final outputs. if self.hidden_sharding_degree > 1: @@ -327,8 +339,6 @@ def parallel_forward_once( # Un-permute locally to setup for the next series of operations. x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - if dist.get_rank() == 0: - print(f"E {x.shape=} {x=}") return x, tokens_per_expert.flatten() @@ -341,30 +351,21 @@ def permute_and_compute( bins: torch.Tensor, expert_capacity: int, top_k: int, - debug: bool = False, ) -> torch.Tensor: # shape: (N, d_model) x = x.view(-1, x.shape[-1]) - if debug: - print(f"C.1 {x.shape=} {x=}") # Route the tokens for MoE computation. # shape: (num_experts, expert_capacity, d_model) x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - if debug: - print(f"C.2 {x.shape=} {x=}") # Perform the expert computation. # shape: (num_experts, expert_capacity, d_model) x = self.mlp(x) - if debug: - print(f"C.3 {x.shape=} {x=}") # Un-route the data for the MoE output. Items that were dropped will be zeroed out. 
# shape: (N, d_model) x = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - if debug: - print(f"C.4 {x.shape=} {x=}") return x @@ -376,8 +377,8 @@ class ParallelDroplessMLP(ParallelMLPBase): When expert parallelism is enabled the forward pass involves a host-device sync. """ - def __init__(self, *, mlp: DroplessMoEMLP): - super().__init__(mlp=mlp) + def __init__(self, *, mlp: DroplessMoEMLP, cache: Optional[BufferCache] = None): + super().__init__(mlp=mlp, cache=cache) def forward_once( self, diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index e06cec995..016cff145 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -370,7 +370,7 @@ def __init__( self.attention = attention.build(d_model, init_device=init_device, cache=cache) self.attention_norm = layer_norm.build(d_model, init_device=init_device) self.feed_forward_moe = feed_forward_moe.build( - d_model=d_model, num_layers=num_blocks, init_device=init_device + d_model=d_model, num_layers=num_blocks, init_device=init_device, cache=cache ) self.feed_forward_norm = layer_norm.build(d_model, init_device=init_device) self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 4fe930baa..dfde008b3 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -114,9 +114,6 @@ def run_moe_with_expert_parallelism( # Run forward pass. output = moe(batch) assert output.shape == batch.shape - if dist.get_rank() == 0: - print(f"{output=}") - print(f"{expected_output=}") torch.testing.assert_close(output, expected_output) losses = moe.compute_losses(total_tokens // ep_mesh.size()) @@ -151,7 +148,8 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t hidden_size=256, router=MoERouterConfig( top_k=1, - uniform_expert_assignment=moe_type == MoEType.default, + uniform_expert_assignment=moe_type + == MoEType.default, # EP results may be different otherwise dtype=DType.from_pt(dtype), ), z_loss_weight=0.1, From 06784d039d7eef3e56e7fd8c5b7636d21b9a0599 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 17:14:55 -0800 Subject: [PATCH 096/230] fix --- src/olmo_core/nn/moe/moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 6dc2cbe84..6046f0a68 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -248,6 +248,7 @@ def __init__( lb_loss_weight: Optional[float] = None, z_loss_weight: Optional[float] = None, dtype: torch.dtype = torch.float32, + cache: Optional[BufferCache] = None, ): super().__init__( d_model=d_model, @@ -261,6 +262,7 @@ def __init__( z_loss_weight=z_loss_weight, dtype=dtype, capacity_factor=capacity_factor, + cache=cache, ) def _init_parallel_mlp( # type: ignore[override] From 9a2cdadaddebaa6759dbd06fbd559f21fe36e0e7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 17:42:50 -0800 Subject: [PATCH 097/230] cache --- src/olmo_core/nn/moe/moe.py | 5 + src/olmo_core/nn/moe/parallel_mlp.py | 103 ++++++++++++------ src/olmo_core/nn/transformer/model.py | 10 +- .../train/train_module/transformer.py | 6 +- 4 files changed, 89 insertions(+), 35 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 6046f0a68..08b485c86 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -154,6 +154,9 @@ def __init__( ) ) + def warmup_cache(self, 
max_local_microbatch_size: int): + self.experts.warmup_cache(max_local_microbatch_size) + def compute_losses( self, total_bz: Union[int, torch.Tensor], reset: bool = True ) -> Dict[str, torch.Tensor]: @@ -284,6 +287,7 @@ def _init_parallel_mlp( # type: ignore[override] dtype=dtype, init_device=init_device, ), + top_k=self.router.top_k, capacity_factor=capacity_factor, cache=cache, ) @@ -312,5 +316,6 @@ def _init_parallel_mlp( # type: ignore[override] dtype=dtype, init_device=init_device, ), + top_k=self.router.top_k, cache=cache, ) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index e0b1590d2..315040a6e 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -10,7 +10,7 @@ from torch.distributed import DeviceMesh from olmo_core.distributed.utils import get_world_size -from olmo_core.utils import move_to_device +from olmo_core.utils import get_default_device, move_to_device from ..buffer_cache import BufferCache from . import ops @@ -24,14 +24,18 @@ class ParallelMLPBase(nn.Module): Wraps an MoE MLP layer to coordinate the routing and expert parallelism. """ - def __init__(self, *, mlp: MoEMLPBase, cache: Optional[BufferCache] = None): + def __init__(self, *, mlp: MoEMLPBase, top_k: int, cache: Optional[BufferCache] = None): super().__init__() self.mlp = mlp + self.top_k = top_k self._cache = cache or BufferCache() self._expert_parallel_enabled: bool = False self._ep_mesh: Optional[DeviceMesh] = None self._ep_pg: Optional[dist.ProcessGroup] = None + def warmup_cache(self, max_local_microbatch_size: int): + del max_local_microbatch_size + @property def d_model(self) -> int: return self.mlp.d_model @@ -149,17 +153,56 @@ def parallel_forward_once( class ParallelMLP(ParallelMLPBase): - def __init__(self, *, mlp: MoEMLP, capacity_factor: float, cache: Optional[BufferCache] = None): - super().__init__(mlp=mlp, cache=cache) + def __init__( + self, + *, + mlp: MoEMLP, + top_k: int, + capacity_factor: float, + cache: Optional[BufferCache] = None, + max_local_microbatch_size: Optional[int] = None, + ): + super().__init__(mlp=mlp, top_k=top_k, cache=cache) self.capacity_factor = capacity_factor + self.max_local_microbatch_size = max_local_microbatch_size + if self.max_local_microbatch_size is not None: + self.warmup_cache(self.max_local_microbatch_size) + + def warmup_cache(self, max_local_microbatch_size: int): + self.max_local_microbatch_size = max_local_microbatch_size + # TODO: call `_get_parallel_indices_and_bins()` up-front to warm the cache so + # torch.compile() doesn't try to trace that. + expert_capacity = self.expert_capacity(self.max_local_microbatch_size) + local_expert_capacity = expert_capacity // self.ep_world_size + self._get_parallel_indices_and_bins( + expert_capacity=expert_capacity, + local_expert_capacity=local_expert_capacity, + device=get_default_device(), + ) + + def apply_ep(self, ep_mesh: DeviceMesh): + super().apply_ep(ep_mesh) + if self.max_local_microbatch_size is not None: + self.warmup_cache(self.max_local_microbatch_size) - def expert_capacity(self, top_k: int, num_items: int) -> int: - # TODO: need to ensure this is the same across the process group. + def expert_capacity(self, local_batch_size: int) -> int: + # NOTE: need to ensure this is the same across the process group. # If local batch sizes are different then these will be different, and `parallel_forward_once` # will break. This shouldn't be a problem with our trainer, but would be an issue for inference. 
- num_global_items = num_items * self.ep_world_size - num_global_expert_inputs = top_k * num_global_items + # To avoid that you could set `self.max_local_microbatch_size` up-front. + if self.max_local_microbatch_size is not None: + if local_batch_size > self.max_local_microbatch_size: + raise RuntimeError( + f"Local batch size ({local_batch_size:,d}) bigger than " + f"configured max local batch size ({self.max_local_microbatch_size:,d})" + ) + else: + local_batch_size = self.max_local_microbatch_size + + num_global_items = local_batch_size * self.ep_world_size + num_global_expert_inputs = self.top_k * num_global_items inputs_per_expert = num_global_expert_inputs / self.num_experts + return int(self.capacity_factor * inputs_per_expert) @torch.no_grad() @@ -219,12 +262,12 @@ def forward_once( expert_weights: torch.Tensor, expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - num_items, top_k = expert_weights.shape - expert_capacity = self.expert_capacity(top_k, num_items) + batch_size, _ = expert_weights.shape + expert_capacity = self.expert_capacity(batch_size) - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_weights = expert_weights.flatten() - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_indices = expert_indices.flatten() with torch.no_grad(): @@ -236,7 +279,7 @@ def forward_once( expert_weights=expert_weights, bins=bins, expert_capacity=expert_capacity, - top_k=top_k, + top_k=self.top_k, ) return x, batch_size_per_expert @@ -268,13 +311,13 @@ def parallel_forward_once( # shape: (N, d_model) x = x.view(-1, x.shape[-1]) - num_items, top_k = expert_weights.shape - expert_capacity = self.expert_capacity(top_k, num_items) + num_items, _ = expert_weights.shape + expert_capacity = self.expert_capacity(num_items) local_expert_capacity = expert_capacity // self.ep_world_size - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_weights = expert_weights.flatten() - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_indices = expert_indices.flatten() with torch.no_grad(): @@ -282,7 +325,7 @@ def parallel_forward_once( # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, local_expert_capacity, d_model) - x = ops.binned_gather(x, indices, bins, local_expert_capacity, top_k) + x = ops.binned_gather(x, indices, bins, local_expert_capacity, self.top_k) # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. @@ -338,7 +381,7 @@ def parallel_forward_once( x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) # Un-permute locally to setup for the next series of operations. - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() @@ -377,8 +420,8 @@ class ParallelDroplessMLP(ParallelMLPBase): When expert parallelism is enabled the forward pass involves a host-device sync. 
""" - def __init__(self, *, mlp: DroplessMoEMLP, cache: Optional[BufferCache] = None): - super().__init__(mlp=mlp, cache=cache) + def __init__(self, *, mlp: DroplessMoEMLP, top_k: int, cache: Optional[BufferCache] = None): + super().__init__(mlp=mlp, top_k=top_k, cache=cache) def forward_once( self, @@ -386,11 +429,9 @@ def forward_once( expert_weights: torch.Tensor, expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - top_k = expert_weights.shape[-1] - - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_weights = expert_weights.flatten() - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_indices = expert_indices.flatten() with torch.no_grad(): @@ -403,7 +444,7 @@ def forward_once( bin_ids=bin_ids, expert_weights=expert_weights, bins=bins, - top_k=top_k, + top_k=self.top_k, ) return out, batch_size_per_expert @@ -418,11 +459,9 @@ def parallel_forward_once( # but with extra bookkeeping to manage the dynamic sizes, and unfortunately this introduces # a host-device sync. - top_k = expert_weights.shape[-1] - - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_weights = expert_weights.flatten() - # shape: (N * top_k,) + # shape: (batch_size * top_k,) expert_indices = expert_indices.flatten() with torch.no_grad(): @@ -451,7 +490,7 @@ def parallel_forward_once( # Permute locally and without any padding so that tokens for each # parallel device are stored contiguously. - x = ops.gather(x.view(-1, x.shape[-1]), indices, bin_ids, bins, top_k) + x = ops.gather(x.view(-1, x.shape[-1]), indices, bin_ids, bins, self.top_k) # Compute the number of tokens that will be received from each # device and permute the input data across the devices. @@ -546,7 +585,7 @@ def parallel_forward_once( x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) # Un-permute locally to setup for the next series of operations. - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 935238d1e..f586a481e 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -155,13 +155,16 @@ def init_weights( self, *, max_seq_len: Optional[int] = None, + max_local_microbatch_size: Optional[int] = None, device: Optional[torch.device] = None, ) -> torch.Generator: """ Initialize the model weights. - :param max_seq_len: The maximum sequence length expected during training. This is used + :param max_seq_len: The maximum sequence length expected. This is used to warm up the RoPE cache. + :param max_local_microbatch_size: The maximum local (rank) micro-batch size (in tokens) + expected. This is used to warm-up some MoE cache. :param device: The device the local copy of the model will be trained on. 
""" device = device or self.device @@ -203,8 +206,11 @@ def init_weights( generator=generator, ) else: + block = cast(MoETransformerBlock, block) + if max_local_microbatch_size is not None: + block.feed_forward_moe.warmup_cache(max_local_microbatch_size) self.init_method.init_feed_forward_moe( - cast(MoETransformerBlock, block).feed_forward_moe, + block.feed_forward_moe, d_model=self.d_model, block_idx=block.block_idx, num_blocks=self.n_layers, diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index f9ccf6fe9..a39169745 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -362,7 +362,11 @@ def __init__( # Materialize and init parameters. log.info("Initializing model weights...") - self.model.init_weights(max_seq_len=max_sequence_length, device=self.device) + self.model.init_weights( + max_seq_len=max_sequence_length, + max_local_microbatch_size=rank_microbatch_size, + device=self.device, + ) # Build optimizer(s). log.info("Building optimizer...") From a6b7b938f2e1bf8953fd71609c27294411441773 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 17:45:55 -0800 Subject: [PATCH 098/230] clean up --- src/olmo_core/nn/moe/parallel_mlp.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 315040a6e..ac02e6ac9 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -342,11 +342,8 @@ def parallel_forward_once( # Start the cross-device permutation asynchronously so we can # overlap communication with computation. # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) - parallel_x, parallel_x_handle = ops.all_to_all( - x, - group=self._ep_pg, - async_op=True, - ) + # ~= (num_local_experts, expert_capacity, d_model) + parallel_x, _ = ops.all_to_all(x, group=self._ep_pg) # After we do the cross-device permutation we have the tokens on the # correct device but not yet grouped by expert because we received @@ -360,10 +357,6 @@ def parallel_forward_once( ) # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. 
- parallel_x_handle.wait() - # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) - # ~= (num_local_experts, expert_capacity, d_model) parallel_x = self.permute_and_compute( parallel_x, indices=parallel_indices.int(), From 0024a8c979c2ad3bca57651dc5250f7a3aee6e7a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 19:35:45 -0800 Subject: [PATCH 099/230] add tests for ops --- src/test/nn/moe/ops_test.py | 74 +++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 src/test/nn/moe/ops_test.py diff --git a/src/test/nn/moe/ops_test.py b/src/test/nn/moe/ops_test.py new file mode 100644 index 000000000..927b5fdb5 --- /dev/null +++ b/src/test/nn/moe/ops_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from olmo_core.nn.moe import ops + +from ...utils import requires_gpu + + +@requires_gpu +@pytest.mark.parametrize( + ("sl", "hs", "ne", "top_k"), + [ + (4, 2, 2, 1), + (4, 2, 2, 2), + (4, 2, 2, 4), + (1024, 1536, 4, 1), + (1024, 1536, 4, 2), + (1024, 1536, 4, 4), + (1024, 1536, 64, 1), + (1024, 1536, 64, 2), + (1024, 1536, 64, 4), + (1024, 1536, 128, 1), + (1024, 1536, 128, 2), + (1024, 1536, 128, 4), + (16384, 768, 4, 1), + (16384, 768, 4, 2), + (16384, 768, 4, 4), + (16384, 768, 64, 1), + (16384, 768, 64, 2), + (16384, 768, 64, 4), + (16384, 768, 128, 1), + (16384, 768, 128, 2), + (16384, 768, 128, 4), + ], +) +def test_binned_gather(sl: int, hs: int, ne: int, top_k: int): + # NOTE: Capacity factor == 1. + ec = (sl * top_k) // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + _, indices = torch.sort(top_expert) + bins = torch.cumsum(torch.histc(top_expert, ne, min=0, max=ne - 1), 0) + + def binned_gather( + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + ec: int, + top_k: int, + ): + x = x.cpu().numpy() + indices = indices.cpu().numpy() + bins = bins.cpu().numpy() + start = 0 + out = np.zeros((ne, ec, hs)) + for i in range(ne): + end = bins[i] + for j in range(min(ec, end - start)): + index = indices[start + j] // top_k + out[i, j, :] = x[index, :] + start = end + return torch.from_numpy(out).cuda().half() + + out = ops.binned_gather(x, indices, bins, ec, top_k) + expected_out = binned_gather(x, indices, bins, ec, top_k) + assert torch.all(torch.eq(out, expected_out)) From bc61bfa2363c2369cbc078f3ffbd1d13b9cc6568 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 19:37:42 -0800 Subject: [PATCH 100/230] fix? 
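The ops test above builds its routing metadata the same way `parallel_mlp.py` does: sort the
token indices by their assigned expert and take a cumulative histogram of the assignments as
the bin boundaries. A minimal, self-contained sketch of that bookkeeping (toy sizes, values
chosen by hand; the stable sort is only used here so the example output is deterministic):

import torch

# 4 tokens, 2 experts, top_k = 1.
top_expert = torch.tensor([1, 0, 1, 0], dtype=torch.int32)
_, indices = torch.sort(top_expert, stable=True)  # token ids grouped by expert: [1, 3, 0, 2]
bins = torch.cumsum(torch.bincount(top_expert, minlength=2), 0).int()  # [2, 4]
# `ops.binned_gather(x, indices, bins, ec, top_k)` then copies tokens 1 and 3 into expert 0's
# rows and tokens 0 and 2 into expert 1's rows, keeping at most `ec` rows per expert.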
--- src/test/nn/moe/ops_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/test/nn/moe/ops_test.py b/src/test/nn/moe/ops_test.py index 927b5fdb5..13d5f5575 100644 --- a/src/test/nn/moe/ops_test.py +++ b/src/test/nn/moe/ops_test.py @@ -56,19 +56,19 @@ def binned_gather( ec: int, top_k: int, ): - x = x.cpu().numpy() - indices = indices.cpu().numpy() - bins = bins.cpu().numpy() + x_np = x.cpu().numpy() + indices_np = indices.cpu().numpy() + bins_np = bins.cpu().numpy() start = 0 out = np.zeros((ne, ec, hs)) for i in range(ne): - end = bins[i] + end = bins_np[i] for j in range(min(ec, end - start)): - index = indices[start + j] // top_k - out[i, j, :] = x[index, :] + index = indices_np[start + j] // top_k + out[i, j, :] = x_np[index, :] start = end return torch.from_numpy(out).cuda().half() out = ops.binned_gather(x, indices, bins, ec, top_k) - expected_out = binned_gather(x, indices, bins, ec, top_k) + expected_out = binned_gather(x, indices.int(), bins, ec, top_k) assert torch.all(torch.eq(out, expected_out)) From 74c4e11ed7d5d843adc25de8952cc1869289267e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 19:41:20 -0800 Subject: [PATCH 101/230] fix? --- src/test/nn/moe/ops_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/ops_test.py b/src/test/nn/moe/ops_test.py index 13d5f5575..e8af27864 100644 --- a/src/test/nn/moe/ops_test.py +++ b/src/test/nn/moe/ops_test.py @@ -69,6 +69,6 @@ def binned_gather( start = end return torch.from_numpy(out).cuda().half() - out = ops.binned_gather(x, indices, bins, ec, top_k) - expected_out = binned_gather(x, indices.int(), bins, ec, top_k) + out = ops.binned_gather(x, indices.int(), bins.int(), ec, top_k) + expected_out = binned_gather(x, indices, bins, ec, top_k) assert torch.all(torch.eq(out, expected_out)) From 069a77d63a9dae1644612c393579e8de2b21864a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 19:44:42 -0800 Subject: [PATCH 102/230] add another test --- src/test/nn/moe/ops_test.py | 89 ++++++++++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/ops_test.py b/src/test/nn/moe/ops_test.py index e8af27864..6e86725be 100644 --- a/src/test/nn/moe/ops_test.py +++ b/src/test/nn/moe/ops_test.py @@ -47,7 +47,8 @@ def test_binned_gather(sl: int, hs: int, ne: int, top_k: int): # Randomly assign tokens to experts. 
top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() _, indices = torch.sort(top_expert) - bins = torch.cumsum(torch.histc(top_expert, ne, min=0, max=ne - 1), 0) + indices = indices.int() + bins = torch.cumsum(torch.histc(top_expert, ne, min=0, max=ne - 1), 0).int() def binned_gather( x: torch.Tensor, @@ -69,6 +70,90 @@ def binned_gather( start = end return torch.from_numpy(out).cuda().half() - out = ops.binned_gather(x, indices.int(), bins.int(), ec, top_k) + out = ops.binned_gather(x, indices, bins, ec, top_k) expected_out = binned_gather(x, indices, bins, ec, top_k) assert torch.all(torch.eq(out, expected_out)) + + +@requires_gpu +@pytest.mark.parametrize( + ("sl", "hs", "ne", "top_k"), + [ + (4, 2, 2, 1), + (4, 2, 2, 2), + (4, 2, 2, 4), + (1024, 1536, 4, 1), + (1024, 1536, 4, 2), + (1024, 1536, 4, 4), + (1024, 1536, 64, 1), + (1024, 1536, 64, 2), + (1024, 1536, 64, 4), + (1024, 1536, 128, 1), + (1024, 1536, 128, 2), + (1024, 1536, 128, 4), + (16384, 768, 4, 1), + (16384, 768, 4, 2), + (16384, 768, 4, 4), + (16384, 768, 64, 1), + (16384, 768, 64, 2), + (16384, 768, 64, 4), + (16384, 768, 128, 1), + (16384, 768, 128, 2), + (16384, 768, 128, 4), + ], +) +def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int): + # NOTE: Capacity factor == 1. + ec = (sl * top_k) // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + _, indices = torch.sort(top_expert) + indices = indices.int() + bins = torch.cumsum(torch.histc(top_expert, ne, min=0, max=ne - 1), 0).int() + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + x = ops.binned_gather(x, indices, bins, ec, top_k) + + def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + x_np = x.cpu().numpy() + indices_np = indices.cpu().numpy() + weights_np = weights.cpu().numpy() + bins_np = bins.cpu().numpy() + start = 0 + out = np.zeros((sl, hs)) + for i in range(ne): + end = bins_np[i] + for j in range(min(ec, end - start)): + index = indices_np[start + j] + scale = weights_np[index] + index //= top_k + + out[index, :] += scale * x_np[i, j, :] + start = end + return torch.from_numpy(out).cuda().half() + + out = ops.binned_scatter(x, indices, weights, bins, top_k) + expected_out = binned_scatter(x, indices, weights, bins, top_k) + + # NOTE: We need to check approximate equality because the + # scatter reduce uses atomics. 
+ assert ( + np.testing.assert_allclose( + out.cpu(), + expected_out.cpu(), + rtol=5e-3, + ) + is None + ) From 650f030ef22aedd2ed3c7d75dd80680376a2484c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 19:52:31 -0800 Subject: [PATCH 103/230] test with shared --- src/test/nn/moe/moe_test.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index dfde008b3..a5be7c11b 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -16,7 +16,13 @@ build_expert_parallel_mesh, ) from olmo_core.distributed.utils import get_local_tensor -from olmo_core.nn.moe import MoEBase, MoEConfig, MoERouterConfig, MoEType +from olmo_core.nn.moe import ( + MoEBase, + MoEConfig, + MoERouterConfig, + MoEType, + SharedMLPConfig, +) from olmo_core.utils import get_default_device, seed_all from ...distributed.utils import requires_multi_gpu, run_distributed_test @@ -30,8 +36,9 @@ def init_mlp_weights(moe: MoEBase): @requires_gpu @pytest.mark.parametrize("moe_type", [MoEType.dropless, MoEType.default]) +@pytest.mark.parametrize("shared", [False, True]) @pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) -def test_moe(moe_type, dtype): +def test_moe(moe_type: MoEType, shared: bool, dtype: torch.dtype): seed_all(42) d_model = 128 @@ -39,7 +46,8 @@ def test_moe(moe_type, dtype): name=moe_type, num_experts=4, hidden_size=256, - router=MoERouterConfig(top_k=1, dtype=DType.from_pt(dtype)), + router=MoERouterConfig(top_k=1), + shared_mlp=None if not shared else SharedMLPConfig(), z_loss_weight=0.1, dtype=DType.from_pt(dtype), ) From c875ab7a5cdf33452322add9f96400881ec73cba Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 4 Feb 2025 19:54:16 -0800 Subject: [PATCH 104/230] fix --- src/olmo_core/nn/moe/shared_mlp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/olmo_core/nn/moe/shared_mlp.py b/src/olmo_core/nn/moe/shared_mlp.py index dbcf8a277..fc4f0ec04 100644 --- a/src/olmo_core/nn/moe/shared_mlp.py +++ b/src/olmo_core/nn/moe/shared_mlp.py @@ -70,7 +70,6 @@ def build( d_model=d_model, hidden_size=hidden_size, init_device=init_device, - dtype=kwargs.pop("dtype").as_pt(), ) if self.dtype is not None: kwargs["dtype"] = self.dtype.as_pt() From b925af2569f36e2d41fc2f093d6d014ce11f3d5e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 12:45:23 -0800 Subject: [PATCH 105/230] check losses --- src/olmo_core/nn/moe/parallel_mlp.py | 23 +++++++++++++---------- src/test/nn/moe/moe_test.py | 22 ++++++++++++++++++++-- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index ac02e6ac9..49e268919 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -104,12 +104,14 @@ def forward( expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - :param x: The input of shape ``(*, d_model)``. - :param expert_weights: Expert weights of shape ``(N, top_k)``. - :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. - - :returns: The output with the same shape as ``x`` and a tensor with shape ``(num_experts,)`` - containing the number of items/tokens routed to each expert. + :param x: The input of shape ``(*, d_model)``, typically ``(num_docs, seq_len, d_model)`` + such that ``num_docs x seq_len = batch_size``. 
+ :param expert_weights: Expert weights of shape ``(batch_size, top_k)``, where ``batch_size`` + typically equals ``num_docs x seq_len``. + :param expert_indices: The indices of the top-k experts, shape ``(batch_size, top_k)``. + + :returns: The output with the same shape as ``x`` and a tensor with shape ``(num_local_experts,)`` + containing the number of items/tokens routed to each (local) expert. """ in_shape = x.size() @@ -129,10 +131,11 @@ def forward_once( expert_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - :param x: The input of shape ``(*, d_model)``. - :param expert_weights: Expert weights of shape ``(N, top_k)``, where ``N`` - typically equals ``batch_size x seq_len``. - :param expert_indices: The indices of the top-k experts, shape ``(N, top_k)``. + :param x: The input of shape ``(*, d_model)``, typically ``(num_docs, seq_len, d_model)`` + such that ``num_docs x seq_len = batch_size``. + :param expert_weights: Expert weights of shape ``(batch_size, top_k)``, where ``batch_size`` + typically equals ``num_docs x seq_len``. + :param expert_indices: The indices of the top-k experts, shape ``(batch_size, top_k)``. """ raise NotImplementedError diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index a5be7c11b..3e7eaaafc 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -76,6 +76,7 @@ def test_moe(moe_type: MoEType, shared: bool, dtype: torch.dtype): losses = moe.compute_losses(B * S) lb_loss = losses["load balancing loss"] assert math.isfinite(lb_loss.item()) + z_loss = losses["router Z loss"] assert math.isfinite(z_loss.item()) loss = lb_loss + z_loss @@ -91,6 +92,8 @@ def run_moe_with_expert_parallelism( d_model: int, batch: torch.Tensor, expected_output: torch.Tensor, + expected_lb_loss: torch.Tensor, + expected_z_loss: torch.Tensor, ): seed_all(42) @@ -125,14 +128,21 @@ def run_moe_with_expert_parallelism( torch.testing.assert_close(output, expected_output) losses = moe.compute_losses(total_tokens // ep_mesh.size()) + lb_loss = losses["load balancing loss"] assert math.isfinite(lb_loss.item()) + total_lb_loss = lb_loss.detach().clone() + dist.all_reduce(total_lb_loss) + torch.testing.assert_close(total_lb_loss, expected_lb_loss) z_loss = losses["router Z loss"] assert math.isfinite(z_loss.item()) - loss = lb_loss + z_loss + total_z_loss = z_loss.detach().clone() + dist.all_reduce(total_z_loss) + torch.testing.assert_close(total_z_loss, expected_z_loss) # Run backward pass. 
+ loss = lb_loss + z_loss loss.backward() assert batch.grad is not None @@ -195,5 +205,13 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t run_moe_with_expert_parallelism, backend="nccl", start_method="spawn", - func_args=(tmp_path, config, d_model, batch.detach().cpu(), output.detach().cpu()), + func_args=( + tmp_path, + config, + d_model, + batch.detach().cpu(), + output.detach().cpu(), + lb_loss.detach().cpu(), + z_loss.detach().cpu(), + ), ) From 48d7e9a4163554db978c0eaacfc1a55e9f53dd8e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 12:46:39 -0800 Subject: [PATCH 106/230] fix --- src/test/nn/moe/moe_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 3e7eaaafc..54175e209 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -133,13 +133,13 @@ def run_moe_with_expert_parallelism( assert math.isfinite(lb_loss.item()) total_lb_loss = lb_loss.detach().clone() dist.all_reduce(total_lb_loss) - torch.testing.assert_close(total_lb_loss, expected_lb_loss) + torch.testing.assert_close(total_lb_loss, expected_lb_loss.to(total_lb_loss.device)) z_loss = losses["router Z loss"] assert math.isfinite(z_loss.item()) total_z_loss = z_loss.detach().clone() dist.all_reduce(total_z_loss) - torch.testing.assert_close(total_z_loss, expected_z_loss) + torch.testing.assert_close(total_z_loss, expected_z_loss.to(total_z_loss.device)) # Run backward pass. loss = lb_loss + z_loss From d47c29df3a748f6436514a9d9ee47265d5aee790 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 12:51:46 -0800 Subject: [PATCH 107/230] fix? --- src/test/nn/moe/moe_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 54175e209..38aefffb2 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -97,7 +97,7 @@ def run_moe_with_expert_parallelism( ): seed_all(42) - ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=min(dist.get_world_size(), 2))) + ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=dist.get_world_size())) moe = config.build(d_model=d_model, num_layers=1, init_device="meta") moe.apply_ep(ep_mesh) @@ -129,15 +129,17 @@ def run_moe_with_expert_parallelism( losses = moe.compute_losses(total_tokens // ep_mesh.size()) + # Check load balancing loss. lb_loss = losses["load balancing loss"] assert math.isfinite(lb_loss.item()) - total_lb_loss = lb_loss.detach().clone() + total_lb_loss = lb_loss.detach() / dist.get_world_size() dist.all_reduce(total_lb_loss) torch.testing.assert_close(total_lb_loss, expected_lb_loss.to(total_lb_loss.device)) + # Check Z loss. 
z_loss = losses["router Z loss"] assert math.isfinite(z_loss.item()) - total_z_loss = z_loss.detach().clone() + total_z_loss = z_loss.detach() / dist.get_world_size() dist.all_reduce(total_z_loss) torch.testing.assert_close(total_z_loss, expected_z_loss.to(total_z_loss.device)) From 0d1ea1f5d558e2bd7b0ea8d6369cedac59f9f154 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 12:57:26 -0800 Subject: [PATCH 108/230] clean up --- src/test/nn/moe/moe_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 38aefffb2..3ae20261d 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -107,7 +107,6 @@ def run_moe_with_expert_parallelism( load_model_and_optim_state(checkpoint_dir, moe) # Split batch and expected output across process group. - total_tokens = batch.shape[0] * batch.shape[1] batch = get_local_tensor( distribute_tensor( batch.to(device=get_default_device()), device_mesh=ep_mesh, placements=(Shard(0),) @@ -127,7 +126,7 @@ def run_moe_with_expert_parallelism( assert output.shape == batch.shape torch.testing.assert_close(output, expected_output) - losses = moe.compute_losses(total_tokens // ep_mesh.size()) + losses = moe.compute_losses(batch.shape[0] * batch.shape[1]) # Check load balancing loss. lb_loss = losses["load balancing loss"] From 4d621d88281df49fbd0963886d9390b7e0e96620 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 14:23:10 -0800 Subject: [PATCH 109/230] clean up --- src/olmo_core/nn/moe/loss.py | 33 ++++++++----- src/olmo_core/nn/moe/moe.py | 21 +++------ src/olmo_core/nn/moe/router.py | 17 +++++-- src/olmo_core/nn/transformer/block.py | 8 ++-- src/olmo_core/nn/transformer/model.py | 17 +++++-- .../train/train_module/transformer.py | 46 ++++++++----------- src/test/nn/moe/moe_test.py | 16 ++++--- 7 files changed, 86 insertions(+), 72 deletions(-) diff --git a/src/olmo_core/nn/moe/loss.py b/src/olmo_core/nn/moe/loss.py index 6b622b1e7..566a0c443 100644 --- a/src/olmo_core/nn/moe/loss.py +++ b/src/olmo_core/nn/moe/loss.py @@ -8,7 +8,14 @@ class MoELoss(metaclass=ABCMeta): @abstractmethod - def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): + def update( + self, + *, + expert_logits: torch.Tensor, + expert_scores: torch.Tensor, + batch_size_per_expert: torch.Tensor, + **kwargs, + ): raise NotImplementedError @abstractmethod @@ -27,18 +34,21 @@ class MoELoadBalancingLoss(MoELoss): Implements the load balancing loss from Switch Transformers. 
""" - def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int, top_k: int): + def __init__(self, *, loss_weight: float, num_experts: int, top_k: int): self.loss_weight = loss_weight - self.num_layers = num_layers self.num_experts = num_experts self.top_k = top_k self.loss: Optional[torch.Tensor] = None - def update(self, expert_logits: torch.Tensor, *, batch_size_per_expert: torch.Tensor, **kwargs): + def update( + self, + *, + expert_scores: torch.Tensor, + batch_size_per_expert: torch.Tensor, + **kwargs, + ): del kwargs - # shape: (N, num_experts) - expert_scores = expert_logits.softmax(dim=-1) - # shape: (num_experts,) + # shape: (batch_size, num_local_experts) -> (num_local_experts,) expert_scores = expert_scores.mean(dim=0) loss = torch.dot(batch_size_per_expert.type_as(expert_scores), expert_scores) if self.loss is None: @@ -54,7 +64,7 @@ def compute( raise RuntimeError( f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" ) - scale = (self.num_experts * self.loss_weight) / (self.num_layers * total_bz * self.top_k) + scale = (self.num_experts * self.loss_weight) / (total_bz * self.top_k) lb_loss = scale * self.loss if reset: self.reset() @@ -65,13 +75,12 @@ def reset(self): class MoERouterZLoss(MoELoss): - def __init__(self, *, loss_weight: float, num_layers: int, num_experts: int): + def __init__(self, *, loss_weight: float, num_experts: int): self.loss_weight = loss_weight - self.num_layers = num_layers self.num_experts = num_experts self.loss: Optional[torch.Tensor] = None - def update(self, expert_logits: torch.Tensor, **kwargs): + def update(self, *, expert_logits: torch.Tensor, **kwargs): del kwargs loss = torch.logsumexp(expert_logits, dim=-1).square().sum() if self.loss is None: @@ -87,7 +96,7 @@ def compute( raise RuntimeError( f"'{self.__class__.__name__}.update()' needs to be called before '.compute()'" ) - scale = self.loss_weight / (self.num_layers * total_bz) + scale = self.loss_weight / total_bz lb_loss = scale * self.loss if reset: self.reset() diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 08b485c86..4aa1dae2f 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -72,7 +72,6 @@ def build( self, d_model: int, *, - num_layers: int, init_device: str = "cpu", cache: Optional[BufferCache] = None, ) -> "MoEBase": @@ -80,7 +79,6 @@ def build( kwargs.pop("name") kwargs.update( d_model=d_model, - num_layers=num_layers, init_device=init_device, dtype=kwargs.pop("dtype").as_pt(), cache=cache, @@ -111,7 +109,6 @@ def __init__( num_experts: int, hidden_size: int, router: MoERouterConfig, - num_layers: int, shared_mlp: Optional[SharedMLPConfig] = None, init_device: str = "cpu", lb_loss_weight: Optional[float] = None, @@ -136,23 +133,17 @@ def __init__( if shared_mlp is None else shared_mlp.build(d_model, hidden_size, dtype=dtype, init_device=init_device) ) - self.num_layers = num_layers self.losses: List[MoELoss] = [] if lb_loss_weight is not None: self.losses.append( MoELoadBalancingLoss( loss_weight=lb_loss_weight, - num_layers=num_layers, num_experts=num_experts, top_k=self.router.top_k, ) ) if z_loss_weight is not None: - self.losses.append( - MoERouterZLoss( - loss_weight=z_loss_weight, num_layers=num_layers, num_experts=num_experts - ) - ) + self.losses.append(MoERouterZLoss(loss_weight=z_loss_weight, num_experts=num_experts)) def warmup_cache(self, max_local_microbatch_size: int): self.experts.warmup_cache(max_local_microbatch_size) @@ -191,7 +182,7 @@ def forward(self, x: 
torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. """ - expert_logits, expert_weights, exper_indices = self.router(x) + expert_logits, expert_scores, expert_weights, exper_indices = self.router(x) out, batch_size_per_expert = self.experts(x, expert_weights, exper_indices) if self.shared_experts is not None: out = self.shared_experts(x, out, self.router.top_k) @@ -199,7 +190,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training and self.losses: expert_logits = expert_logits.float() for loss_fn in self.losses: - loss_fn.update(expert_logits, batch_size_per_expert=batch_size_per_expert) + loss_fn.update( + expert_logits=expert_logits, + expert_scores=expert_scores, + batch_size_per_expert=batch_size_per_expert, + ) return out @@ -244,7 +239,6 @@ def __init__( num_experts: int, hidden_size: int, router: MoERouterConfig, - num_layers: int, shared_mlp: Optional[SharedMLPConfig] = None, capacity_factor: float = 1.2, init_device: str = "cpu", @@ -258,7 +252,6 @@ def __init__( num_experts=num_experts, hidden_size=hidden_size, router=router, - num_layers=num_layers, shared_mlp=shared_mlp, init_device=init_device, lb_loss_weight=lb_loss_weight, diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 14f7a9c6b..3b6ca0949 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -165,13 +165,17 @@ def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: """ raise NotImplementedError - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Given the input ``x`` of shape ``(*, d_model)``, compute the experts assignment. - :returns: The logits of shape ``(N, num_experts)``, the expert weights - of shape ``(N, top_k)``, and the expert indices of shape ``(N, top_k)``. + :returns: The unnormalized scores (logits) of shape ``(N, num_experts)``, + the normalized scores of shape ``(N, num_experts)``, + the expert weights of shape ``(N, top_k)``, + and the expert indices of shape ``(N, top_k)``. 
""" # shape: (batch_size, seq_len, d_model) x = self.jitter(x) @@ -179,8 +183,11 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te # shape: (batch_size * seq_len, num_experts) logits = self.get_expert_logits(x.view(-1, self.d_model)) + # shape: (batch_size * seq_len, num_experts) + scores = logits.softmax(dim=-1) + # shape: (batch_size * seq_len, top_k) - expert_weights, expert_indices = self.get_top_k(logits) + expert_weights, expert_indices = self.get_top_k(scores) if self.normalize_expert_weights is not None: expert_weights.div_( @@ -195,7 +202,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te if self.uniform_expert_assignment: expert_indices = _uniform_expert_assignment(expert_indices, self.num_experts) - return logits, expert_weights, expert_indices + return logits, scores, expert_weights, expert_indices class MoELinearRouter(MoERouter): diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 016cff145..1321f8dee 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -90,7 +90,6 @@ def build( *, d_model: int, block_idx: int, - num_blocks: int, init_device: str = "cpu", cache: Optional[BufferCache] = None, ) -> "TransformerBlockBase": @@ -111,9 +110,9 @@ def build( elif self.name == TransformerBlockType.normalized: return NormalizedTransformerBlock(**kwargs) elif self.name == TransformerBlockType.moe: - return MoETransformerBlock(num_blocks=num_blocks, **kwargs) + return MoETransformerBlock(**kwargs) elif self.name == TransformerBlockType.moe_reordered_norm: - return MoEReorderedNormTransformerBlock(num_blocks=num_blocks, **kwargs) + return MoEReorderedNormTransformerBlock(**kwargs) else: raise NotImplementedError(self.name) except TypeError as e: @@ -359,7 +358,6 @@ def __init__( attention: AttentionConfig, feed_forward_moe: MoEConfig, layer_norm: LayerNormConfig, - num_blocks: int, dropout: float = 0.0, init_device: str = "cpu", cache: Optional[BufferCache] = None, @@ -370,7 +368,7 @@ def __init__( self.attention = attention.build(d_model, init_device=init_device, cache=cache) self.attention_norm = layer_norm.build(d_model, init_device=init_device) self.feed_forward_moe = feed_forward_moe.build( - d_model=d_model, num_layers=num_blocks, init_device=init_device, cache=cache + d_model=d_model, init_device=init_device, cache=cache ) self.feed_forward_norm = layer_norm.build(d_model, init_device=init_device) self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index f586a481e..1e9b0f8dc 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -116,7 +116,6 @@ def __init__( block_ = block.build( d_model=d_model, block_idx=block_idx, - num_blocks=n_layers, init_device=init_device, cache=cache, ) @@ -139,6 +138,15 @@ def __init__( def _validate_block(self, block: TransformerBlockBase): del block + def compute_auxiliary_losses( + self, total_bz: Union[int, torch.Tensor], reset: bool = True + ) -> Dict[str, torch.Tensor]: + del total_bz, reset + return {} + + def reset_auxiliary_losses(self): + pass + @property def is_moe(self) -> bool: return False @@ -625,7 +633,7 @@ def _validate_block(self, block: TransformerBlockBase): f"'{self.__class__.__name__}' requires a '{MoETransformerBlock.__name__}' block" ) - def compute_losses( + def compute_auxiliary_losses( self, total_bz: Union[int, 
torch.Tensor], reset: bool = True ) -> Dict[str, torch.Tensor]: out: Dict[str, torch.Tensor] = {} @@ -633,15 +641,16 @@ def compute_losses( for loss_name, loss_val in ( cast(MoETransformerBlock, block).compute_losses(total_bz, reset=reset).items() ): + loss_val.div_(self.n_layers) if loss_name in out: out[loss_name] += loss_val else: out[loss_name] = loss_val return out - def reset_losses(self): + def reset_auxiliary_losses(self): for block in self.blocks.values(): - cast(MoETransformer, block).reset_losses() + cast(MoETransformerBlock, block).reset_losses() def forward( self, diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index a39169745..eb07aacbe 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -529,9 +529,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): z_batch_loss: Optional[torch.Tensor] = None if self.z_loss_multiplier is not None: z_batch_loss = move_to_device(torch.tensor(0.0), self.device) - moe_batch_losses: Optional[Dict[str, torch.Tensor]] = None - if self.model.is_moe: - moe_batch_losses = {} + auxiliary_batch_losses: Dict[str, torch.Tensor] = {} # Split into micro-batches. if self.rank_microbatch_size < (seq_len := batch["input_ids"].shape[1]): @@ -561,20 +559,18 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): z_batch_loss += z_loss del z_loss - # Optionally get MoE losses and update the total batch MoE losses. - if self.model.is_moe: - assert moe_batch_losses is not None - moe_losses = cast(MoETransformer, self.model).compute_losses( - batch_num_tokens_for_loss, reset=True - ) - for loss_name, loss_val in moe_losses.items(): - loss += loss_val - loss_val = get_local_tensor(loss_val.detach()) - if loss_name in moe_batch_losses: - moe_batch_losses[loss_name] += loss_val - else: - moe_batch_losses[loss_name] = loss_val - del moe_losses + # Optionally get model auxiliary losses and update the total batch auxiliary losses. + auxiliary_losses = self.model.compute_auxiliary_losses( + batch_num_tokens_for_loss, reset=True + ) + for loss_name, loss_val in auxiliary_losses.items(): + loss += loss_val + loss_val = get_local_tensor(loss_val.detach()) + if loss_name in auxiliary_batch_losses: + auxiliary_batch_losses[loss_name] += loss_val + else: + auxiliary_batch_losses[loss_name] = loss_val + del auxiliary_losses # Run backward pass. loss.backward() @@ -594,15 +590,13 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): ReduceType.mean, namespace="train", ) - if self.model.is_moe: - assert moe_batch_losses is not None - for loss_name, loss_val in moe_batch_losses.items(): - self.record_metric( - loss_name, - loss_val, - ReduceType.mean, - namespace="train", - ) + for loss_name, loss_val in auxiliary_batch_losses.items(): + self.record_metric( + loss_name, + loss_val, + ReduceType.mean, + namespace="train", + ) if isinstance(self.optim, SkipStepOptimizer): self.optim.latest_loss = ce_batch_loss diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 3ae20261d..66713e407 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -51,7 +51,7 @@ def test_moe(moe_type: MoEType, shared: bool, dtype: torch.dtype): z_loss_weight=0.1, dtype=DType.from_pt(dtype), ) - moe = config.build(d_model=d_model, num_layers=1, init_device="cuda") + moe = config.build(d_model=d_model, init_device="cuda") init_mlp_weights(moe) # Check num params calculation. 
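The auxiliary-loss plumbing above follows one pattern end to end: each MoE block reports its named losses, the model divides every value by the number of layers before summing across blocks, and the train module both adds the result to the micro-batch loss and accumulates detached copies per name for metrics. A minimal single-process sketch of that aggregation step, assuming toy per-block dictionaries (the helper and names below are illustrative, not the real API):

    from typing import Dict

    import torch


    def aggregate_block_losses(
        per_block: Dict[str, Dict[str, torch.Tensor]], n_layers: int
    ) -> Dict[str, torch.Tensor]:
        # Average each named loss over the layer count while summing across blocks,
        # mirroring the per-block division by `n_layers` in the diff above.
        out: Dict[str, torch.Tensor] = {}
        for block_losses in per_block.values():
            for name, val in block_losses.items():
                val = val / n_layers
                out[name] = out[name] + val if name in out else val
        return out


    per_block = {
        "block0": {"load balancing loss": torch.tensor(0.4), "router Z loss": torch.tensor(0.2)},
        "block1": {"load balancing loss": torch.tensor(0.6), "router Z loss": torch.tensor(0.1)},
    }
    print(aggregate_block_losses(per_block, n_layers=2))
    # {'load balancing loss': tensor(0.5000), 'router Z loss': tensor(0.1500)}
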
@@ -99,7 +99,7 @@ def run_moe_with_expert_parallelism( ep_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=dist.get_world_size())) - moe = config.build(d_model=d_model, num_layers=1, init_device="meta") + moe = config.build(d_model=d_model, init_device="meta") moe.apply_ep(ep_mesh) moe.to_empty(device=get_default_device()) @@ -131,9 +131,13 @@ def run_moe_with_expert_parallelism( # Check load balancing loss. lb_loss = losses["load balancing loss"] assert math.isfinite(lb_loss.item()) - total_lb_loss = lb_loss.detach() / dist.get_world_size() - dist.all_reduce(total_lb_loss) - torch.testing.assert_close(total_lb_loss, expected_lb_loss.to(total_lb_loss.device)) + + # NOTE: This particular load-balancing loss may differ in distributed case, or even with + # gradient accumulation due to ``batch_size_per_expert`` being the local. + # total_lb_loss = lb_loss.detach() / dist.get_world_size() + # dist.all_reduce(total_lb_loss) + # torch.testing.assert_close(total_lb_loss, expected_lb_loss.to(total_lb_loss.device)) + del expected_lb_loss # Check Z loss. z_loss = losses["router Z loss"] @@ -174,7 +178,7 @@ def test_moe_with_expert_parallelism(tmp_path: Path, moe_type: MoEType, dtype: t z_loss_weight=0.1, dtype=DType.from_pt(dtype), ) - moe = config.build(d_model=d_model, num_layers=1, init_device="cpu") + moe = config.build(d_model=d_model, init_device="cpu") moe.to(device=device) init_mlp_weights(moe) From 5db88503a6597550f98c76ac9a9183008a1ab0e7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 14:44:53 -0800 Subject: [PATCH 110/230] comments --- src/olmo_core/train/train_module/transformer_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 33bfe290e..8470bb86f 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -448,7 +448,7 @@ def __init__( for model in self.model_parts: if model.is_moe: - # TODO (epwalsh): need to handle the internal MoE losses correctly. + # TODO (epwalsh): need to handle the MoE auxiliary losses correctly. raise NotImplementedError( "Pipeline parallelism with MoE's is currently not supported" ) @@ -669,7 +669,6 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): self.record_ce_loss( self._ce_batch_loss / get_world_size(self.dp_process_group), ReduceType.sum ) - if self.z_loss_multiplier is not None: if self._z_batch_loss is None: self.record_metric("Z loss", 0.0, ReduceType.sum, namespace="train") @@ -680,6 +679,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): ReduceType.sum, namespace="train", ) + # TODO: handle model auxiliary losses, like with MoE. 
for optim in self.optimizers: if isinstance(optim, SkipStepOptimizer): From 3fd2c641ae48f931f6613ccde612823b6081066f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 15:20:07 -0800 Subject: [PATCH 111/230] add batched histc --- src/olmo_core/nn/moe/ops.py | 12 ++++++++++++ src/olmo_core/nn/moe/parallel_mlp.py | 2 +- src/test/nn/moe/ops_test.py | 9 ++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py index ab9b4fdf3..831f7f3e7 100644 --- a/src/olmo_core/nn/moe/ops.py +++ b/src/olmo_core/nn/moe/ops.py @@ -4,6 +4,8 @@ import torch import torch.distributed as dist +from olmo_core.utils import move_to_device + def _is_eligible(x): return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) @@ -314,3 +316,13 @@ def sum_tensor(x: torch.Tensor, dim: int = 0) -> torch.Tensor: if x.shape[dim] == 1: return x.squeeze(dim=dim) return x.sum(dim=dim) + + +def batched_histc(x: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + A batched version of ``torch.histc``. + """ + hist = move_to_device(torch.zeros((*x.shape[:-1], num_classes), dtype=x.dtype), x.device) + ones = move_to_device(torch.tensor(1, dtype=x.dtype), x.device).expand_as(x) + hist.scatter_add_(-1, ((x * num_classes) // (x.max() + 1)).long(), ones) + return hist diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 49e268919..0490000ea 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -85,7 +85,7 @@ def indices_and_bins( # Sort the expert ids to produce the scatter/gather # indices for the permutation. - # shape: (N,), (N,) + # shape: (batch_size,), (batch_size,) # TODO: for non-dropless MoE, should do secondary sort by expert weight so we drop tokens # with the lowest expert weight. bin_ids, indices = torch.sort(expert_indices) diff --git a/src/test/nn/moe/ops_test.py b/src/test/nn/moe/ops_test.py index 6e86725be..4b89067ec 100644 --- a/src/test/nn/moe/ops_test.py +++ b/src/test/nn/moe/ops_test.py @@ -7,7 +7,7 @@ from olmo_core.nn.moe import ops -from ...utils import requires_gpu +from ...utils import DEVICES, requires_gpu @requires_gpu @@ -157,3 +157,10 @@ def binned_scatter( ) is None ) + + +@pytest.mark.parametrize("device", DEVICES) +def test_batched_histc(device: torch.device): + x = torch.tensor([[0, 1, 1], [2, 0, 0]], device=device) + hist = ops.batched_histc(x, 3) + torch.testing.assert_close(hist, torch.tensor([[1, 2, 0], [2, 0, 1]], device=device)) From 91365f36009b23bb5300df03fb1302c9df0ac369 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 16:08:04 -0800 Subject: [PATCH 112/230] stuff --- src/olmo_core/nn/moe/parallel_mlp.py | 6 ++++-- src/test/nn/moe/ops_test.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 0490000ea..02355626d 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -77,6 +77,8 @@ def indices_and_bins( # Histogram the expert ids to identify the number of # items/tokens routed to each expert. # shape: (num_experts,), LongTensor + # NOTE: if we wanted to keep the batch dimension here like for sequence-level load balancing + # loss, we could use `opts.batched_histc`. 
batch_size_per_expert = torch.histc( expert_indices, bins=self.num_experts, min=0, max=self.num_experts - 1 ) @@ -324,7 +326,7 @@ def parallel_forward_once( expert_indices = expert_indices.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(expert_indices) + indices, bin_ids, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, local_expert_capacity, d_model) @@ -379,7 +381,7 @@ def parallel_forward_once( # Un-permute locally to setup for the next series of operations. x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() + return x, batch_size_per_expert.flatten() def permute_and_compute( self, diff --git a/src/test/nn/moe/ops_test.py b/src/test/nn/moe/ops_test.py index 4b89067ec..d45f0a35f 100644 --- a/src/test/nn/moe/ops_test.py +++ b/src/test/nn/moe/ops_test.py @@ -102,7 +102,7 @@ def binned_gather( (16384, 768, 128, 4), ], ) -def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int): +def test_binned_scatter(sl: int, hs: int, ne: int, top_k: int): # NOTE: Capacity factor == 1. ec = (sl * top_k) // ne From ba2f6ad3de5302fa997b981128090decf9e964f9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 5 Feb 2025 16:27:11 -0800 Subject: [PATCH 113/230] add config builder for MoE --- src/olmo_core/nn/transformer/config.py | 31 +++++++++++++++++++++++++- src/scripts/train/OLMoE-1B-7B.py | 25 ++------------------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 2d2781d39..475a3d0c9 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -8,6 +8,7 @@ from ..feed_forward import FeedForwardConfig, FeedForwardType from ..layer_norm import LayerNormConfig, LayerNormType from ..lm_head import LMHeadConfig, LMHeadType +from ..moe import MoEConfig, MoERouterConfig, MoEType from ..rope import RoPEConfig, RoPEScalingConfig, RoPEType from .block import TransformerBlockConfig, TransformerBlockType from .init import InitMethod @@ -280,6 +281,27 @@ def olmo2_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": **kwargs, ) + @classmethod + def olmoe_1B_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + d_model = kwargs.pop("d_model", 2048) + config = cls.olmo2_1B( + d_model=d_model, + vocab_size=vocab_size, + n_layers=kwargs.pop("n_layers", 16), + n_heads=kwargs.pop("n_heads", 16), + name=kwargs.pop("name", TransformerType.moe), + block_name=kwargs.pop("block_name", TransformerBlockType.moe_reordered_norm), + feed_forward_moe=MoEConfig( + name=MoEType.dropless, + num_experts=64, + hidden_size=int(0.5 * d_model), + router=MoERouterConfig(top_k=8, bias=False), + lb_loss_weight=0.01, + z_loss_weight=0.001, + ), + ) + return config + @classmethod def olmo2_3B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( @@ -548,6 +570,8 @@ def llama_like( block_name: TransformerBlockType = TransformerBlockType.default, dtype: DType = DType.float32, rope_scaling: Optional[RoPEScalingConfig] = None, + feed_forward: Optional[FeedForwardConfig] = None, + feed_forward_moe: Optional[MoEConfig] = None, **kwargs, ) -> "TransformerConfig": """ @@ -583,6 +607,10 @@ def llama_like( att_type = AttentionType.fused rope_type = RoPEType.fused + # Feed-forward. 
+ if feed_forward is None and feed_forward_moe is None: + feed_forward = FeedForwardConfig(hidden_size=hidden_size, bias=False, dtype=dtype) + # Configure blocks. block = TransformerBlockConfig( name=block_name, @@ -596,7 +624,8 @@ def llama_like( use_flash=use_flash, dtype=dtype, ), - feed_forward=FeedForwardConfig(hidden_size=hidden_size, bias=False, dtype=dtype), + feed_forward=feed_forward, + feed_forward_moe=feed_forward_moe, layer_norm=layer_norm, ) diff --git a/src/scripts/train/OLMoE-1B-7B.py b/src/scripts/train/OLMoE-1B-7B.py index 3824bac38..f4cd5155d 100644 --- a/src/scripts/train/OLMoE-1B-7B.py +++ b/src/scripts/train/OLMoE-1B-7B.py @@ -6,12 +6,7 @@ from olmo_core.config import DType from olmo_core.distributed.parallel import DataParallelType from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.moe import MoEConfig, MoERouterConfig, MoEType -from olmo_core.nn.transformer import ( - TransformerBlockType, - TransformerConfig, - TransformerType, -) +from olmo_core.nn.transformer import TransformerConfig from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride from olmo_core.train import TrainerConfig from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback @@ -23,23 +18,7 @@ def build_model_config(common: CommonComponents) -> TransformerConfig: - model_config = TransformerConfig.olmo2_1B( - vocab_size=common.tokenizer.padded_vocab_size(), - n_layers=16, - n_heads=16, - block_name=TransformerBlockType.moe_reordered_norm, - ) - model_config.name = TransformerType.moe - model_config.block.feed_forward = None - model_config.block.feed_forward_moe = MoEConfig( - name=MoEType.dropless, - num_experts=64, - hidden_size=int(0.5 * model_config.d_model), - router=MoERouterConfig(top_k=8, bias=False), - lb_loss_weight=0.01, - z_loss_weight=0.001, - ) - return model_config + return TransformerConfig.olmoe_1B_7B(vocab_size=common.tokenizer.padded_vocab_size()) def build_train_module_config(common: CommonComponents) -> TransformerTrainModuleConfig: From 108fd69743b185dd12f66aeddcc4704eb5e8ba14 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 12:58:02 -0800 Subject: [PATCH 114/230] Add SmallMoE config --- src/examples/llama/train.py | 3 +- src/olmo_core/nn/moe/shared_mlp.py | 4 +- src/olmo_core/nn/transformer/config.py | 70 ++++++++++++++++++-------- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/src/examples/llama/train.py b/src/examples/llama/train.py index f7a7a4121..171a51a40 100644 --- a/src/examples/llama/train.py +++ b/src/examples/llama/train.py @@ -56,7 +56,8 @@ class ExperimentConfig(Config): def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: tokenizer_config = TokenizerConfig.gpt2() - model_config = TransformerConfig.llama2_271M( + # model_config = TransformerConfig.llama2_271M( + model_config = TransformerConfig.smallmoe( vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128 ) diff --git a/src/olmo_core/nn/moe/shared_mlp.py b/src/olmo_core/nn/moe/shared_mlp.py index fc4f0ec04..9cff98ef9 100644 --- a/src/olmo_core/nn/moe/shared_mlp.py +++ b/src/olmo_core/nn/moe/shared_mlp.py @@ -33,6 +33,7 @@ class SharedMLPConfig(Config): The name of the implementation. 
""" weighted_sum: bool = True + hidden_size: Optional[int] = None bias: bool = True dtype: Optional[DType] = None @@ -44,6 +45,7 @@ def num_params(self, d_model: int, hidden_size: int) -> int: """ params = 0 + hidden_size = self.hidden_size or hidden_size params += 3 * d_model * hidden_size if self.bias: params += 2 * hidden_size + d_model @@ -68,9 +70,9 @@ def build( kwargs.pop("name") kwargs.update( d_model=d_model, - hidden_size=hidden_size, init_device=init_device, ) + kwargs.setdefault("hidden_size", hidden_size) if self.dtype is not None: kwargs["dtype"] = self.dtype.as_pt() elif dtype is not None: diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 475a3d0c9..40893415e 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -8,7 +8,7 @@ from ..feed_forward import FeedForwardConfig, FeedForwardType from ..layer_norm import LayerNormConfig, LayerNormType from ..lm_head import LMHeadConfig, LMHeadType -from ..moe import MoEConfig, MoERouterConfig, MoEType +from ..moe import MoEConfig, MoERouterConfig, MoEType, SharedMLPConfig from ..rope import RoPEConfig, RoPEScalingConfig, RoPEType from .block import TransformerBlockConfig, TransformerBlockType from .init import InitMethod @@ -281,27 +281,6 @@ def olmo2_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": **kwargs, ) - @classmethod - def olmoe_1B_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": - d_model = kwargs.pop("d_model", 2048) - config = cls.olmo2_1B( - d_model=d_model, - vocab_size=vocab_size, - n_layers=kwargs.pop("n_layers", 16), - n_heads=kwargs.pop("n_heads", 16), - name=kwargs.pop("name", TransformerType.moe), - block_name=kwargs.pop("block_name", TransformerBlockType.moe_reordered_norm), - feed_forward_moe=MoEConfig( - name=MoEType.dropless, - num_experts=64, - hidden_size=int(0.5 * d_model), - router=MoERouterConfig(top_k=8, bias=False), - lb_loss_weight=0.01, - z_loss_weight=0.001, - ), - ) - return config - @classmethod def olmo2_3B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( @@ -366,6 +345,53 @@ def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": **kwargs, ) + @classmethod + def smallmoe(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + d_model = kwargs.pop("d_model", 768) + return cls.llama_like( + d_model=d_model, + vocab_size=vocab_size, + n_layers=kwargs.pop("n_layers", 12), + n_heads=kwargs.pop("n_heads", 12), + name=kwargs.pop("name", TransformerType.moe), + block_name=kwargs.pop("block_name", TransformerBlockType.moe_reordered_norm), + qk_norm=kwargs.pop("qk_norm", True), + rope_theta=kwargs.pop("rope_theta", 500_000), + layer_norm_eps=1e-6, + feed_forward_moe=MoEConfig( + name=MoEType.default, + num_experts=32, + hidden_size=int(0.5 * d_model), + router=MoERouterConfig(top_k=8, bias=False), + shared_mlp=SharedMLPConfig(hidden_size=d_model * 2, bias=False), + lb_loss_weight=0.01, + z_loss_weight=0.001, + ), + ) + + @classmethod + def olmoe_1B_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + d_model = kwargs.pop("d_model", 2048) + return cls.llama_like( + d_model=d_model, + vocab_size=vocab_size, + n_layers=kwargs.pop("n_layers", 16), + n_heads=kwargs.pop("n_heads", 16), + name=kwargs.pop("name", TransformerType.moe), + block_name=kwargs.pop("block_name", TransformerBlockType.moe_reordered_norm), + qk_norm=kwargs.pop("qk_norm", True), + rope_theta=kwargs.pop("rope_theta", 500_000), + layer_norm_eps=1e-6, + feed_forward_moe=MoEConfig( 
+ name=MoEType.dropless, + num_experts=64, + hidden_size=int(0.5 * d_model), + router=MoERouterConfig(top_k=8, bias=False), + lb_loss_weight=0.01, + z_loss_weight=0.001, + ), + ) + @classmethod def ngpt_271M(cls, vocab_size: int, **kwargs) -> "TransformerConfig": """ From 8cd7ed17b7e4db57298e911aad325fd4b9e9fa5f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 15:43:05 -0800 Subject: [PATCH 115/230] use replicate with EP --- .../distributed/parallel/__init__.py | 4 +-- src/olmo_core/nn/moe/mlp.py | 28 ++++++++++++++----- src/olmo_core/nn/moe/parallel_mlp.py | 22 +++++++-------- 3 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index 9e242d65d..055f9b170 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -192,9 +192,7 @@ def build_expert_parallel_mesh( for i, (name, dim) in enumerate(zip(names, dims)): log.info(f" > dimension {i}, size={dim}, name={name}") - return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names))[ - MeshDimName.ep_shard - ] + return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) def get_dp_mesh( diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index e75613768..6b0458017 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -2,10 +2,11 @@ from typing import Any, Callable, Optional import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.distributed import DeviceMesh -from torch.distributed.tensor import Shard, distribute_tensor +from torch.distributed.tensor import Replicate, Shard, distribute_tensor from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError @@ -45,6 +46,7 @@ def __init__( self.gradient_scale: Optional[float] = None self.num_local_experts = num_experts self.hidden_sharding_degree = 1 + self.ep_pg: Optional[dist.ProcessGroup] = None def scale_grad(self, w: torch.Tensor) -> torch.Tensor: if self.gradient_scale is None: @@ -55,9 +57,18 @@ def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. """ - if ep_mesh.ndim > 1: - raise RuntimeError("local expert parallel sub-mesh must be 1-dimensional") - num_shards = ep_mesh.size() + from torch.distributed._composable.replicate import replicate + + if ep_mesh.ndim != 1: + raise RuntimeError("expert parallel mesh must be 2-dimensional") + if not ep_mesh.mesh_dim_names: + raise RuntimeError("expert parallel mesh must have named dimensions") + + replicate_dim_name, shard_dim_name = ep_mesh.mesh_dim_names + + self.ep_pg = ep_mesh[shard_dim_name].get_group() + num_shards = ep_mesh[shard_dim_name].size() + if self.num_experts % num_shards != 0: raise OLMoConfigurationError( f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel shard degree ({num_shards})." 
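The divisibility check above is what determines how many experts each expert-parallel rank ends up owning once ``Shard(0)`` is applied to the ``(num_experts, hidden_size, d_model)`` weights: the expert dimension is split into contiguous, equally sized chunks. A standalone sketch of that bookkeeping, assuming an evenly divisible expert count (the helper name is made up for illustration):

    def local_expert_ids(num_experts: int, num_shards: int, shard_rank: int) -> range:
        # Shard(0) gives each shard rank a contiguous slice of the expert dimension:
        # experts [shard_rank * n_local, (shard_rank + 1) * n_local).
        assert num_experts % num_shards == 0, "num_experts must divide evenly across shards"
        n_local = num_experts // num_shards
        return range(shard_rank * n_local, (shard_rank + 1) * n_local)


    # 64 experts over an expert-parallel shard degree of 8 -> 8 local experts per rank.
    print(list(local_expert_ids(64, 8, shard_rank=3)))  # [24, 25, 26, 27, 28, 29, 30, 31]
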
@@ -66,9 +77,12 @@ def apply_ep(self, ep_mesh: DeviceMesh): self.num_local_experts = self.num_experts // num_shards self.gradient_scale = 1.0 / num_shards - self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)]))) # type: ignore - self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)]))) # type: ignore - self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, [Shard(0)]))) # type: ignore + placements = [Replicate(), Shard(0)] + self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, placements))) # type: ignore + self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, placements))) # type: ignore + self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, placements))) # type: ignore + + replicate(self, device_mesh=ep_mesh[replicate_dim_name]) class MoEMLP(MoEMLPBase): diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 02355626d..e3939d330 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -30,8 +30,6 @@ def __init__(self, *, mlp: MoEMLPBase, top_k: int, cache: Optional[BufferCache] self.top_k = top_k self._cache = cache or BufferCache() self._expert_parallel_enabled: bool = False - self._ep_mesh: Optional[DeviceMesh] = None - self._ep_pg: Optional[dist.ProcessGroup] = None def warmup_cache(self, max_local_microbatch_size: int): del max_local_microbatch_size @@ -54,19 +52,21 @@ def hidden_sharding_degree(self) -> int: @property def ep_world_size(self) -> int: - if self._ep_pg is not None: - return get_world_size(self._ep_pg) + if self.ep_pg is not None: + return get_world_size(self.ep_pg) else: return 1 + @property + def ep_pg(self) -> Optional[dist.ProcessGroup]: + return self.mlp.ep_pg + def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. """ self.mlp.apply_ep(ep_mesh) self._expert_parallel_enabled = True - self._ep_mesh = ep_mesh - self._ep_pg = None if ep_mesh is None else ep_mesh.get_group() def indices_and_bins( self, expert_indices: torch.Tensor @@ -348,7 +348,7 @@ def parallel_forward_once( # overlap communication with computation. # shape: (num_local_experts * ep_world_size, local_expert_capacity, d_model) # ~= (num_local_experts, expert_capacity, d_model) - parallel_x, _ = ops.all_to_all(x, group=self._ep_pg) + parallel_x, _ = ops.all_to_all(x, group=self.ep_pg) # After we do the cross-device permutation we have the tokens on the # correct device but not yet grouped by expert because we received @@ -372,7 +372,7 @@ def parallel_forward_once( ) # Un-permute the tokens across the devices. - x, _ = ops.all_to_all(parallel_x, group=self._ep_pg) + x, _ = ops.all_to_all(parallel_x, group=self.ep_pg) # Reduce along the hidden sharding to get the final outputs. if self.hidden_sharding_degree > 1: @@ -481,7 +481,7 @@ def parallel_forward_once( tpe_handle = dist.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, - group=self._ep_pg, + group=self.ep_pg, async_op=True, ) assert tpe_handle is not None @@ -520,7 +520,7 @@ def parallel_forward_once( x, recv_counts, send_counts, - group=self._ep_pg, + group=self.ep_pg, async_op=True, ) @@ -576,7 +576,7 @@ def parallel_forward_once( parallel_x, send_counts, recv_counts, - group=self._ep_pg, + group=self.ep_pg, ) # Reduce along the hidden sharding to get the final outputs. 
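Both forward paths in the patch above hinge on an all-to-all exchange over the expert-parallel process group: each rank permutes its tokens so everything bound for a given rank sits contiguously, the collective then delivers to every rank the tokens for its local experts, and a second all-to-all undoes the exchange after the expert MLPs run. A pure-Python stand-in for that exchange, which is just a transpose of per-rank send chunks (the real code uses the all-to-all ops on ``self.ep_pg``; this sketch skips process groups entirely):

    from typing import List, TypeVar

    T = TypeVar("T")


    def simulated_all_to_all(send: List[List[T]]) -> List[List[T]]:
        # send[src][dst] is the chunk rank `src` sends to rank `dst`; after the
        # exchange, rank `dst` holds [send[0][dst], send[1][dst], ...] in source order.
        world = len(send)
        return [[send[src][dst] for src in range(world)] for dst in range(world)]


    # Two expert-parallel ranks, each sending one chunk of token ids per destination.
    send = [
        [["t0", "t1"], ["t2", "t3"]],  # rank 0's chunks for ranks 0 and 1
        [["t4", "t5"], ["t6", "t7"]],  # rank 1's chunks for ranks 0 and 1
    ]
    print(simulated_all_to_all(send))
    # [[['t0', 't1'], ['t4', 't5']], [['t2', 't3'], ['t6', 't7']]]
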
From 10e845c35d7289e91cae406be15ff851d2591819 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 15:44:47 -0800 Subject: [PATCH 116/230] fix? --- src/olmo_core/distributed/parallel/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index 055f9b170..8b96e6d16 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -170,8 +170,8 @@ def build_expert_parallel_mesh( device_type = device_type or get_default_device().type world_size = get_world_size() - if ep_config.degree == world_size: - return init_device_mesh(device_type, (world_size,), mesh_dim_names=(MeshDimName.ep_shard,)) + # if ep_config.degree == world_size: + # return init_device_mesh(device_type, (world_size,), mesh_dim_names=(MeshDimName.ep_shard,)) # Build up mesh dimensions. names: List[str] = [] From c130305c5fe3df77f39db97f6d1aa6f5ab70daab Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 15:45:32 -0800 Subject: [PATCH 117/230] dumb --- src/olmo_core/nn/moe/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 6b0458017..18ae89b07 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -59,7 +59,7 @@ def apply_ep(self, ep_mesh: DeviceMesh): """ from torch.distributed._composable.replicate import replicate - if ep_mesh.ndim != 1: + if ep_mesh.ndim != 2: raise RuntimeError("expert parallel mesh must be 2-dimensional") if not ep_mesh.mesh_dim_names: raise RuntimeError("expert parallel mesh must have named dimensions") From 2c6bbc7312530a41434713b4b35ce5c4b2508114 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 15:47:57 -0800 Subject: [PATCH 118/230] fix test? --- src/test/nn/moe/moe_test.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/test/nn/moe/moe_test.py b/src/test/nn/moe/moe_test.py index 66713e407..60ec87aa7 100644 --- a/src/test/nn/moe/moe_test.py +++ b/src/test/nn/moe/moe_test.py @@ -4,7 +4,7 @@ import pytest import torch import torch.distributed as dist -from torch.distributed.tensor import Shard, distribute_tensor +from torch.distributed.tensor import Replicate, Shard, distribute_tensor from olmo_core.config import DType from olmo_core.distributed.checkpoint import ( @@ -109,7 +109,12 @@ def run_moe_with_expert_parallelism( # Split batch and expected output across process group. 
batch = get_local_tensor( distribute_tensor( - batch.to(device=get_default_device()), device_mesh=ep_mesh, placements=(Shard(0),) + batch.to(device=get_default_device()), + device_mesh=ep_mesh, + placements=( + Replicate(), + Shard(0), + ), ) ) batch.requires_grad_(True) @@ -117,7 +122,10 @@ def run_moe_with_expert_parallelism( distribute_tensor( expected_output.to(device=get_default_device()), device_mesh=ep_mesh, - placements=(Shard(0),), + placements=( + Replicate(), + Shard(0), + ), ) ) From 9bbbcad1c74c7bc1095d09efc25dac41027f2ed9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 16:28:52 -0800 Subject: [PATCH 119/230] okay, let's try this --- .../distributed/parallel/__init__.py | 115 +++++++++++++----- src/olmo_core/nn/moe/mlp.py | 32 +++-- src/olmo_core/nn/moe/moe.py | 21 +++- src/olmo_core/nn/moe/parallel_mlp.py | 13 +- src/olmo_core/nn/transformer/block.py | 13 +- src/olmo_core/nn/transformer/model.py | 13 +- .../train/train_module/transformer.py | 9 +- .../train_module/transformer_pipeline.py | 2 - 8 files changed, 163 insertions(+), 55 deletions(-) diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index 8b96e6d16..206244a44 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -24,6 +24,7 @@ "get_dp_mesh", "get_tp_mesh", "get_pp_mesh", + "get_ep_mesh", "get_dp_process_group", "DataParallelType", "DataParallelConfig", @@ -84,30 +85,27 @@ def build_device_mesh( dp: Optional[DataParallelConfig] = None, tp: Optional[TensorParallelConfig] = None, pp: Optional[PipelineParallelConfig] = None, + ep: Optional[ExpertParallelConfig] = None, device_type: Optional[str] = None, -) -> Optional[DeviceMesh]: +) -> DeviceMesh: """ Build a ``DeviceMesh`` suitable for the given parallel strategies. The resulting dimension names will be defined in :class:`MeshDimName`. .. important:: A data parallel config is required if any other parallel config is set. - - .. seealso:: - Expert parallel device meshes need to be created separately with - :func:`build_expert_parallel_mesh`. """ - if pp is None and tp is None and dp is None: - return None + device_type = device_type or get_default_device().type + dp_world_size = get_world_size() + + if pp is None and tp is None and dp is None and ep is None: + return init_device_mesh(device_type, (dp_world_size,), mesh_dim_names=(MeshDimName.dp,)) + if dp is None: raise OLMoConfigurationError( "Data parallel config is required in addition to expert/tensor/pipeline parallel configs" ) - device_type = device_type or get_default_device().type - - # Determine data parallel world size. - dp_world_size = get_world_size() if pp is not None: if pp.degree < 1 or dp_world_size % pp.degree != 0: raise OLMoConfigurationError( @@ -120,6 +118,19 @@ def build_device_mesh( f"{tp.__class__.__name__}.degree must be at least 1 and divide into the world size" ) dp_world_size //= tp.degree + if ep is not None: + if ep.degree < 1 or dp_world_size % ep.degree != 0: + raise OLMoConfigurationError( + f"{ep.__class__.__name__}.degree must be at least 1 and divide into the world size" + ) + if tp is not None: + raise OLMoConfigurationError( + "expert parallelism is mutually exclusive with tensor parallism" + ) + if pp is not None: + raise NotImplementedError( + "expert parallelism + pipeline parallelism is not implemented yet" + ) # Build up mesh dimensions. 
names: List[str] = [] @@ -137,16 +148,27 @@ def build_device_mesh( raise OLMoConfigurationError( f"HSDP requires DP world size ({dp_world_size}) to be divisible by 'num_replicas' ({num_replicas})" ) + shard_degree = dp_world_size // num_replicas + if ep is not None: + if ep.degree != shard_degree: + raise OLMoConfigurationError( + "expert parallelism + HSDP requires the same sharding degree" + ) + names.append(MeshDimName.dp_replicate) dims.append(num_replicas) - names.append(MeshDimName.dp_shard) - dims.append(dp_world_size // num_replicas) + dims.append(shard_degree) + elif ep is not None: + names.append(MeshDimName.ep_replicate) + dims.append(dp_world_size // ep.degree) + names.append(MeshDimName.ep_shard) + dims.append(ep.degree) else: names.append(MeshDimName.dp) dims.append(dp_world_size) - # And lastly tensor/expert parallel. + # And lastly tensor parallel. if tp is not None: names.append(MeshDimName.tp) dims.append(tp.degree) @@ -170,9 +192,6 @@ def build_expert_parallel_mesh( device_type = device_type or get_default_device().type world_size = get_world_size() - # if ep_config.degree == world_size: - # return init_device_mesh(device_type, (world_size,), mesh_dim_names=(MeshDimName.ep_shard,)) - # Build up mesh dimensions. names: List[str] = [] dims: List[int] = [] @@ -201,6 +220,8 @@ def get_dp_mesh( dim_name: str = MeshDimName.dp, replicate_dim_name: str = MeshDimName.dp_replicate, shard_dim_name: str = MeshDimName.dp_shard, + ep_replicate_dim_name: str = MeshDimName.ep_replicate, + ep_shard_dim_name: str = MeshDimName.ep_shard, ) -> Optional[DeviceMesh]: """ Get the data parallel sub-mesh associated with a ``DeviceMesh`` that was potentially @@ -223,12 +244,50 @@ def get_dp_mesh( and shard_dim_name in device_mesh.mesh_dim_names ): return device_mesh[replicate_dim_name, shard_dim_name] + elif ( + ep_replicate_dim_name in device_mesh.mesh_dim_names + and ep_shard_dim_name in device_mesh.mesh_dim_names + ): + return device_mesh[ep_replicate_dim_name, ep_shard_dim_name]._flatten( + mesh_dim_name=dim_name + ) else: raise RuntimeError( f"could not determine data parallel sub-mesh from mesh with dimensions {device_mesh.mesh_dim_names}" ) +def get_ep_mesh( + device_mesh: DeviceMesh, + *, + replicate_dim_name: str = MeshDimName.dp_replicate, + shard_dim_name: str = MeshDimName.dp_shard, + ep_replicate_dim_name: str = MeshDimName.ep_replicate, + ep_shard_dim_name: str = MeshDimName.ep_shard, +) -> DeviceMesh: + """ + Get the expert parallel sub-mesh associated with a ``DeviceMesh`` that was potentially + created from :func:`build_device_mesh()`. 
+ """ + if device_mesh.mesh_dim_names is None: + raise RuntimeError("could not determine expert parallel sub-mesh without dimension names") + + if ( + ep_replicate_dim_name in device_mesh.mesh_dim_names + and ep_shard_dim_name in device_mesh.mesh_dim_names + ): + return device_mesh[ep_replicate_dim_name, ep_shard_dim_name] + elif ( + replicate_dim_name in device_mesh.mesh_dim_names + and shard_dim_name in device_mesh.mesh_dim_names + ): + return device_mesh[replicate_dim_name, shard_dim_name] + else: + raise RuntimeError( + f"could not determine expert parallel sub-mesh from mesh with dimensions {device_mesh.mesh_dim_names}" + ) + + def get_dp_process_group( device_mesh: Optional[DeviceMesh] = None, *, @@ -255,43 +314,37 @@ def get_dp_process_group( return dp_mesh.get_group() -def get_tp_mesh( - device_mesh: Optional[DeviceMesh] = None, *, dim_name: str = MeshDimName.tp -) -> Optional[DeviceMesh]: +def get_tp_mesh(device_mesh: DeviceMesh, *, dim_name: str = MeshDimName.tp) -> DeviceMesh: """ Get the tensor parallel sub-mesh associated with a ``DeviceMesh`` that was potentially created from :func:`build_device_mesh()`. :param dim_name: The name of the target mesh dimension. """ - if device_mesh is None: - return None - if device_mesh.mesh_dim_names is None: raise RuntimeError("could not determine tensor parallel sub-mesh without dimension names") if dim_name in device_mesh.mesh_dim_names: return device_mesh[dim_name] else: - return None + raise RuntimeError( + f"could not determine tensor parallel sub-mesh from mesh with dimensions {device_mesh.mesh_dim_names}" + ) -def get_pp_mesh( - device_mesh: Optional[DeviceMesh] = None, *, dim_name: str = MeshDimName.pp -) -> Optional[DeviceMesh]: +def get_pp_mesh(device_mesh: DeviceMesh, *, dim_name: str = MeshDimName.pp) -> DeviceMesh: """ Get the tensor parallel sub-mesh associated with a ``DeviceMesh`` that was potentially created from :func:`build_device_mesh()`. :param dim_name: The name of the target mesh dimension. """ - if device_mesh is None: - return None - if device_mesh.mesh_dim_names is None: raise RuntimeError("could not determine pipeline parallel sub-mesh without dimension names") if dim_name in device_mesh.mesh_dim_names: return device_mesh[dim_name] else: - return None + raise RuntimeError( + f"could not determine pipeline parallel sub-mesh from mesh with dimensions {device_mesh.mesh_dim_names}" + ) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 18ae89b07..c8560c5cf 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -1,12 +1,12 @@ import warnings -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.distributed import DeviceMesh -from torch.distributed.tensor import Replicate, Shard, distribute_tensor +from torch.distributed.tensor import Placement, Replicate, Shard, distribute_tensor from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError @@ -53,18 +53,23 @@ def scale_grad(self, w: torch.Tensor) -> torch.Tensor: return w return _scale_gradient(w, self.gradient_scale) - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep( + self, + ep_mesh: DeviceMesh, + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, + ): """ Apply expert parallelism. 
""" from torch.distributed._composable.replicate import replicate - if ep_mesh.ndim != 2: - raise RuntimeError("expert parallel mesh must be 2-dimensional") + if ep_mesh.ndim > 2: + raise RuntimeError("expert parallel mesh must be 1 or 2D") if not ep_mesh.mesh_dim_names: raise RuntimeError("expert parallel mesh must have named dimensions") - replicate_dim_name, shard_dim_name = ep_mesh.mesh_dim_names + shard_dim_name = ep_mesh.mesh_dim_names[-1] self.ep_pg = ep_mesh[shard_dim_name].get_group() num_shards = ep_mesh[shard_dim_name].size() @@ -77,12 +82,23 @@ def apply_ep(self, ep_mesh: DeviceMesh): self.num_local_experts = self.num_experts // num_shards self.gradient_scale = 1.0 / num_shards - placements = [Replicate(), Shard(0)] + placements: List[Placement] = [Shard(0)] + if ep_mesh.ndim > 1: + placements.insert(0, Replicate()) + self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, placements))) # type: ignore self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, placements))) # type: ignore self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, placements))) # type: ignore - replicate(self, device_mesh=ep_mesh[replicate_dim_name]) + if ep_mesh.ndim > 1: + if compile_enabled: + if autograd_compile_enabled: + torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward" # type: ignore + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" # type: ignore + + replicate_dim_name = ep_mesh.mesh_dim_names[0] + replicate(self, device_mesh=ep_mesh[replicate_dim_name]) class MoEMLP(MoEMLPBase): diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 4aa1dae2f..8d6d974ee 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -198,24 +198,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep( + self, + ep_mesh: DeviceMesh, + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, + ): """ Apply expert parallelism. """ - self.experts.apply_ep(ep_mesh) + self.experts.apply_ep( + ep_mesh, + compile_enabled=compile_enabled, + autograd_compile_enabled=autograd_compile_enabled, + ) def apply_tp( self, tp_mesh: DeviceMesh, output_layouts: Optional[Placement] = None, use_local_output: bool = True, + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, ): parallelize_module( self.router, device_mesh=tp_mesh, parallelize_plan=SequenceParallel(use_local_output=True), ) - self.experts.apply_ep(tp_mesh) + self.experts.apply_ep( + tp_mesh, + compile_enabled=compile_enabled, + autograd_compile_enabled=autograd_compile_enabled, + ) parallelize_module( self, device_mesh=tp_mesh, diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index e3939d330..667369f46 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -61,11 +61,20 @@ def ep_world_size(self) -> int: def ep_pg(self) -> Optional[dist.ProcessGroup]: return self.mlp.ep_pg - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep( + self, + ep_mesh: DeviceMesh, + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, + ): """ Apply expert parallelism. 
""" - self.mlp.apply_ep(ep_mesh) + self.mlp.apply_ep( + ep_mesh, + compile_enabled=compile_enabled, + autograd_compile_enabled=autograd_compile_enabled, + ) self._expert_parallel_enabled = True def indices_and_bins( diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 1321f8dee..804fd8918 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -397,8 +397,17 @@ def forward( ) return h + self.dropout(self.feed_forward_moe(self.feed_forward_norm(h))) - def apply_ep(self, ep_mesh: DeviceMesh): - self.feed_forward_moe.apply_ep(ep_mesh) + def apply_ep( + self, + ep_mesh: DeviceMesh, + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, + ): + self.feed_forward_moe.apply_ep( + ep_mesh, + compile_enabled=compile_enabled, + autograd_compile_enabled=autograd_compile_enabled, + ) def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: return Shard(1) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 1e9b0f8dc..6c074f0b3 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -672,6 +672,15 @@ def forward( return self.lm_head(h) if self.lm_head is not None else h - def apply_ep(self, ep_mesh: DeviceMesh): + def apply_ep( + self, + ep_mesh: DeviceMesh, + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, + ): for block in self.blocks.values(): - cast(MoETransformerBlock, block).apply_ep(ep_mesh) + cast(MoETransformerBlock, block).apply_ep( + ep_mesh, + compile_enabled=compile_enabled, + autograd_compile_enabled=autograd_compile_enabled, + ) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index e735bb796..228d2190a 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -22,9 +22,9 @@ ExpertParallelConfig, TensorParallelConfig, build_device_mesh, - build_expert_parallel_mesh, get_dp_mesh, get_dp_process_group, + get_ep_mesh, get_tp_mesh, ) from olmo_core.distributed.utils import get_local_tensor, get_world_size @@ -273,7 +273,7 @@ def __init__( self.device = device or get_default_device() self.world_mesh = build_device_mesh( - dp=dp_config, tp=tp_config, device_type=self.device.type + dp=dp_config, tp=tp_config, ep=ep_config, device_type=self.device.type ) log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") @@ -305,7 +305,6 @@ def __init__( raise NotImplementedError("TP + EP is not implemented yet") if tp_config is not None: tp_mesh = get_tp_mesh(self.world_mesh) - assert tp_mesh is not None self.model.apply_tp( tp_mesh, float8_enabled=float8_enabled, @@ -318,8 +317,8 @@ def __init__( if ep_config is not None: if not self.model.is_moe: raise OLMoConfigurationError("Expert parallelism is only valid for MoE models") - ep_mesh = build_expert_parallel_mesh(ep_config) - cast(MoETransformer, self.model).apply_ep(ep_mesh) + ep_mesh = get_ep_mesh(self.world_mesh) + cast(MoETransformer, self.model).apply_ep(ep_mesh, compile_enabled=compile_model) log.info("Applied expert parallelism to the model") # Maybe apply activation checkpointing. 
diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index ac20fd38a..02773ae7b 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -346,7 +346,6 @@ def __init__( self.model_parts: List[Transformer] = [] pp_mesh = get_pp_mesh(self.world_mesh) - assert pp_mesh is not None stages, model_parts = pp_config.split_model(model, pp_mesh=pp_mesh, device=self.device) self._pp_stages = stages self.model_parts = model_parts @@ -362,7 +361,6 @@ def __init__( # Maybe apply tensor parallelism. if tp_config is not None: tp_mesh = get_tp_mesh(self.world_mesh) - assert tp_mesh is not None for model in self.model_parts: model.apply_tp( tp_mesh, From b134bc9cd8979df0f7ffbc3a5058b57275d1fcae Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 16:30:02 -0800 Subject: [PATCH 120/230] fix --- src/olmo_core/nn/moe/parallel_mlp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 667369f46..30ab8b058 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -194,8 +194,17 @@ def warmup_cache(self, max_local_microbatch_size: int): device=get_default_device(), ) - def apply_ep(self, ep_mesh: DeviceMesh): - super().apply_ep(ep_mesh) + def apply_ep( + self, + ep_mesh: DeviceMesh, + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, + ): + super().apply_ep( + ep_mesh, + compile_enabled=compile_enabled, + autograd_compile_enabled=autograd_compile_enabled, + ) if self.max_local_microbatch_size is not None: self.warmup_cache(self.max_local_microbatch_size) From 14748277464ed0fcc827eb51781a65c8f0f3ef54 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 16:38:38 -0800 Subject: [PATCH 121/230] require HSDP for expert parallelism --- src/olmo_core/distributed/parallel/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/olmo_core/distributed/parallel/__init__.py b/src/olmo_core/distributed/parallel/__init__.py index 206244a44..ae597f8e0 100644 --- a/src/olmo_core/distributed/parallel/__init__.py +++ b/src/olmo_core/distributed/parallel/__init__.py @@ -123,6 +123,10 @@ def build_device_mesh( raise OLMoConfigurationError( f"{ep.__class__.__name__}.degree must be at least 1 and divide into the world size" ) + if dp.name != DataParallelType.hsdp: + raise OLMoConfigurationError( + "expert parallelism can currently only be used with HSDP data parallelism" + ) if tp is not None: raise OLMoConfigurationError( "expert parallelism is mutually exclusive with tensor parallism" From e7e726e9766a9d015deb550c65a39005d5d3e457 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 16:44:28 -0800 Subject: [PATCH 122/230] fix dtype --- src/olmo_core/nn/moe/mlp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index c8560c5cf..ca35a9bcd 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -153,6 +153,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :param x: The input of shape ``(num_local_experts, N, d_model)``. """ + og_dtype = x.dtype + # Scale gradients and get local tensors (in case of expert parallelism). 
# shape (all): (num_local_experts, hidden_size, d_model) w1, w2, w3 = ( @@ -161,8 +163,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: get_local_tensor(self.scale_grad(self.w3)), ) + x = x.type_as(w1) + # Compute the MLP. - return torch.bmm(F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3), w2) + return torch.bmm(F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3), w2).to(dtype=og_dtype) class DroplessMoEMLP(MoEMLPBase): From 64e384d3a010f5950a2911df58b0a5e6d3cf9634 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 17:12:57 -0800 Subject: [PATCH 123/230] fewer active experts --- src/olmo_core/nn/transformer/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 40893415e..fb797e0a8 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -362,7 +362,7 @@ def smallmoe(cls, vocab_size: int, **kwargs) -> "TransformerConfig": name=MoEType.default, num_experts=32, hidden_size=int(0.5 * d_model), - router=MoERouterConfig(top_k=8, bias=False), + router=MoERouterConfig(top_k=4, bias=False), shared_mlp=SharedMLPConfig(hidden_size=d_model * 2, bias=False), lb_loss_weight=0.01, z_loss_weight=0.001, From 2a9df66b74ab7f3c3525d6e49b78527d47542c16 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 17:17:42 -0800 Subject: [PATCH 124/230] idk --- src/olmo_core/nn/moe/parallel_mlp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 30ab8b058..fbc2bfd65 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -397,7 +397,8 @@ def parallel_forward_once( x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) # Un-permute locally to setup for the next series of operations. - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = ops.binned_scatter(x, indices, expert_weights, bins, self.top_k) return x, batch_size_per_expert.flatten() From 557699e24177bef02fce643db12b2a7f6b651def Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 17:20:15 -0800 Subject: [PATCH 125/230] try this --- src/olmo_core/nn/moe/parallel_mlp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index fbc2bfd65..0e1d77868 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -398,7 +398,9 @@ def parallel_forward_once( # Un-permute locally to setup for the next series of operations. # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = ops.binned_scatter(x, indices, expert_weights, bins, self.top_k) + x = ops.binned_scatter( + x.view(self.num_experts, -1, self.d_model), indices, expert_weights, bins, self.top_k + ) return x, batch_size_per_expert.flatten() From 6972b22bac97e81790632ba759b7ff9ae2d4787e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 22:15:56 -0800 Subject: [PATCH 126/230] fix? 
--- src/olmo_core/nn/moe/parallel_mlp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 0e1d77868..f72e5440d 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -222,11 +222,9 @@ def expert_capacity(self, local_batch_size: int) -> int: else: local_batch_size = self.max_local_microbatch_size - num_global_items = local_batch_size * self.ep_world_size - num_global_expert_inputs = self.top_k * num_global_items - inputs_per_expert = num_global_expert_inputs / self.num_experts + local_inputs_per_expert = self.top_k * local_batch_size / self.num_experts - return int(self.capacity_factor * inputs_per_expert) + return self.ep_world_size * int(self.capacity_factor * local_inputs_per_expert) @torch.no_grad() def _get_parallel_indices_and_bins( From af6a7746ee72cc87ac55280bc03f57fa8ffebed0 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 22:19:38 -0800 Subject: [PATCH 127/230] clean up --- src/olmo_core/nn/moe/parallel_mlp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index f72e5440d..bcdaafe61 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -223,7 +223,6 @@ def expert_capacity(self, local_batch_size: int) -> int: local_batch_size = self.max_local_microbatch_size local_inputs_per_expert = self.top_k * local_batch_size / self.num_experts - return self.ep_world_size * int(self.capacity_factor * local_inputs_per_expert) @torch.no_grad() @@ -342,7 +341,7 @@ def parallel_forward_once( expert_indices = expert_indices.flatten() with torch.no_grad(): - indices, bin_ids, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) + indices, _, bins, batch_size_per_expert = self.indices_and_bins(expert_indices) # Permute locally so that the tokens for each device are stored contiguously. # shape: (num_experts, local_expert_capacity, d_model) @@ -395,7 +394,6 @@ def parallel_forward_once( x = ops.sum_tensor(x.view(self.hidden_sharding_degree, -1, self.d_model), dim=0) # Un-permute locally to setup for the next series of operations. 
- # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) x = ops.binned_scatter( x.view(self.num_experts, -1, self.d_model), indices, expert_weights, bins, self.top_k ) From 75dc2bb5acfc1eb7e7fa9439ab437280995c3c8c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 6 Feb 2025 22:29:23 -0800 Subject: [PATCH 128/230] custom op --- src/olmo_core/nn/moe/parallel_mlp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index bcdaafe61..d29dda88a 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -303,6 +303,9 @@ def forward_once( ) return x, batch_size_per_expert + @torch.library.custom_op( + "olmo_core::moe_parallel_forward_once", mutates_args={}, device_types="cuda" + ) def parallel_forward_once( self, x: torch.Tensor, From c7b248a43d0d3f83e0afa7de95fb0414934bbdc7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 08:48:12 -0800 Subject: [PATCH 129/230] try again --- src/olmo_core/nn/moe/ops.py | 2 ++ src/olmo_core/nn/moe/parallel_mlp.py | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py index 831f7f3e7..b8049b7f8 100644 --- a/src/olmo_core/nn/moe/ops.py +++ b/src/olmo_core/nn/moe/ops.py @@ -181,6 +181,7 @@ def backward(ctx: Any, grad: torch.Tensor): return out, None, None, None, None +@torch.library.custom_op("olmo_core::moe_binned_gather", mutates_args={}, device_types="cuda") def binned_gather( x: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, bin_size: int, top_k: int ) -> torch.Tensor: @@ -237,6 +238,7 @@ def backward(ctx: Any, grad: torch.Tensor): return out, None, wgrad, None, None +@torch.library.custom_op("olmo_core::moe_binned_scatter", mutates_args={}, device_types="cuda") def binned_scatter( x: torch.Tensor, indices: torch.Tensor, diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index d29dda88a..bcdaafe61 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -303,9 +303,6 @@ def forward_once( ) return x, batch_size_per_expert - @torch.library.custom_op( - "olmo_core::moe_parallel_forward_once", mutates_args={}, device_types="cuda" - ) def parallel_forward_once( self, x: torch.Tensor, From 4fb0cfda4fe879a3af05072e62f68db8093968ab Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 08:54:35 -0800 Subject: [PATCH 130/230] try this --- src/olmo_core/nn/moe/ops.py | 2 -- src/olmo_core/nn/moe/parallel_mlp.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/ops.py b/src/olmo_core/nn/moe/ops.py index b8049b7f8..831f7f3e7 100644 --- a/src/olmo_core/nn/moe/ops.py +++ b/src/olmo_core/nn/moe/ops.py @@ -181,7 +181,6 @@ def backward(ctx: Any, grad: torch.Tensor): return out, None, None, None, None -@torch.library.custom_op("olmo_core::moe_binned_gather", mutates_args={}, device_types="cuda") def binned_gather( x: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, bin_size: int, top_k: int ) -> torch.Tensor: @@ -238,7 +237,6 @@ def backward(ctx: Any, grad: torch.Tensor): return out, None, wgrad, None, None -@torch.library.custom_op("olmo_core::moe_binned_scatter", mutates_args={}, device_types="cuda") def binned_scatter( x: torch.Tensor, indices: torch.Tensor, diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index bcdaafe61..29b9f79f2 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ 
b/src/olmo_core/nn/moe/parallel_mlp.py @@ -303,6 +303,7 @@ def forward_once( ) return x, batch_size_per_expert + @torch._dynamo.disable() def parallel_forward_once( self, x: torch.Tensor, From 24a41a0203200cd312f9bbfc24f78cdbab5a4894 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 09:06:54 -0800 Subject: [PATCH 131/230] pre-cast to int --- src/olmo_core/nn/moe/parallel_mlp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 29b9f79f2..985bbaf7e 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -260,6 +260,7 @@ def _get_parallel_indices_and_bins( # shape: (num_local_experts * expert_capacity,) _, parallel_indices = torch.sort(parallel_top_expert) + parallel_indices = parallel_indices.int() # Calculate the bins boundaries from the token counts. # shape: (num_local_experts,) @@ -380,7 +381,7 @@ def parallel_forward_once( # Locally permute the tokens and perform the expert computation. parallel_x = self.permute_and_compute( parallel_x, - indices=parallel_indices.int(), + indices=parallel_indices, expert_weights=None, bins=parallel_bins, expert_capacity=expert_capacity, From 095f38979dcfc951e0f2a21fea6b465ff66ab572 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 09:29:59 -0800 Subject: [PATCH 132/230] debugging --- src/olmo_core/nn/moe/mlp.py | 2 +- src/olmo_core/train/train_module/transformer.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index ca35a9bcd..2f65acb8b 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -145,7 +145,7 @@ def __init__( ) def extra_repr(self): - return f"d_model={self.d_model}, num_experts={self.num_experts}, hidden_size={self.hidden_size}" + return f"num_experts={self.num_experts}, in_features={self.d_model}, hidden_size={self.hidden_size}" def forward(self, x: torch.Tensor) -> torch.Tensor: """ diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 228d2190a..dc0ae4919 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -370,6 +370,8 @@ def __init__( # Build optimizer(s). log.info("Building optimizer...") self.optim: Optimizer = optim.build(self.model, strict=True) + for i, (name, param) in enumerate(self.model.named_parameters()): + log.info(f"param {i+1}: '{name}' ({tuple(param.shape)})") self.rank_microbatch_size = rank_microbatch_size self.max_sequence_length = max_sequence_length From fc1c24817cd0f0ad30fd3b45a18f37fee011ed40 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 09:35:29 -0800 Subject: [PATCH 133/230] try not flattening --- src/olmo_core/train/train_module/transformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index dc0ae4919..aeb874479 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -370,8 +370,6 @@ def __init__( # Build optimizer(s). 
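Note on PATCHES 128-130 above: they first try registering parallel_forward_once and the binned gather/scatter ops with torch.library.custom_op, then settle on a @torch._dynamo.disable() decorator instead. The sketch below is a generic stand-in, not code from this repo; it only shows the pattern, where the decorated function stays eagerly executed and torch.compile graph-breaks around it while still compiling the caller.

    import torch

    @torch._dynamo.disable()
    def routed_forward(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for parallel_forward_once(): data-dependent shapes, sorting,
        # and collectives that the compiler would otherwise try to trace.
        return x * 2.0

    @torch.compile
    def block(x: torch.Tensor) -> torch.Tensor:
        # Compiled caller; `routed_forward` runs outside the captured graph.
        return routed_forward(x) + 1.0

    print(block(torch.randn(4)))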
log.info("Building optimizer...") self.optim: Optimizer = optim.build(self.model, strict=True) - for i, (name, param) in enumerate(self.model.named_parameters()): - log.info(f"param {i+1}: '{name}' ({tuple(param.shape)})") self.rank_microbatch_size = rank_microbatch_size self.max_sequence_length = max_sequence_length @@ -380,10 +378,10 @@ def __init__( self.max_grad_norm = max_grad_norm self.scheduler = scheduler self.state_dict_save_opts = state_dict_save_opts or dist_cp_sd.StateDictOptions( - flatten_optimizer_state_dict=True, cpu_offload=True + flatten_optimizer_state_dict=False, cpu_offload=True ) self.state_dict_load_opts = state_dict_load_opts or dist_cp_sd.StateDictOptions( - flatten_optimizer_state_dict=True, strict=True + flatten_optimizer_state_dict=False, strict=True ) self.load_key_mapping = load_key_mapping self.label_ignore_index = label_ignore_index From f9d70b0180f9f95f13d12d5ea036d9d98ec39f96 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 09:37:07 -0800 Subject: [PATCH 134/230] revert --- src/olmo_core/train/train_module/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index aeb874479..228d2190a 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -378,10 +378,10 @@ def __init__( self.max_grad_norm = max_grad_norm self.scheduler = scheduler self.state_dict_save_opts = state_dict_save_opts or dist_cp_sd.StateDictOptions( - flatten_optimizer_state_dict=False, cpu_offload=True + flatten_optimizer_state_dict=True, cpu_offload=True ) self.state_dict_load_opts = state_dict_load_opts or dist_cp_sd.StateDictOptions( - flatten_optimizer_state_dict=False, strict=True + flatten_optimizer_state_dict=True, strict=True ) self.load_key_mapping = load_key_mapping self.label_ignore_index = label_ignore_index From b1a423c44afd9b51c9c16726282ee4dfa9468372 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:00:56 -0800 Subject: [PATCH 135/230] try this --- src/olmo_core/nn/moe/mlp.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 2f65acb8b..c55e0711e 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -83,22 +83,22 @@ def apply_ep( self.gradient_scale = 1.0 / num_shards placements: List[Placement] = [Shard(0)] - if ep_mesh.ndim > 1: - placements.insert(0, Replicate()) - - self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, placements))) # type: ignore - self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, placements))) # type: ignore - self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, placements))) # type: ignore - - if ep_mesh.ndim > 1: - if compile_enabled: - if autograd_compile_enabled: - torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward" # type: ignore - else: - torch._dynamo.config.optimize_ddp = "ddp_optimizer" # type: ignore - - replicate_dim_name = ep_mesh.mesh_dim_names[0] - replicate(self, device_mesh=ep_mesh[replicate_dim_name]) + # if ep_mesh.ndim > 1: + # placements.insert(0, Replicate()) + + self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh[shard_dim_name], placements))) # type: ignore + self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh[shard_dim_name], 
placements))) # type: ignore + self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh[shard_dim_name], placements))) # type: ignore + + # if ep_mesh.ndim > 1: + # if compile_enabled: + # if autograd_compile_enabled: + # torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward" # type: ignore + # else: + # torch._dynamo.config.optimize_ddp = "ddp_optimizer" # type: ignore + + # replicate_dim_name = ep_mesh.mesh_dim_names[0] + # replicate(self, device_mesh=ep_mesh[replicate_dim_name]) class MoEMLP(MoEMLPBase): From 5c62a6760739c493988d484d156278586ecc52a9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:11:36 -0800 Subject: [PATCH 136/230] let's try this --- src/olmo_core/nn/moe/mlp.py | 11 +++++++++++ src/olmo_core/nn/transformer/model.py | 6 +++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index c55e0711e..c9ba7fa60 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -100,6 +100,17 @@ def apply_ep( # replicate_dim_name = ep_mesh.mesh_dim_names[0] # replicate(self, device_mesh=ep_mesh[replicate_dim_name]) + def fully_shard(self, dp_mesh: Optional[DeviceMesh] = None, **kwargs): + from torch.distributed._composable.fsdp import fully_shard + + if self.ep_pg is None or dp_mesh is None or dp_mesh.ndim != 2: + return + + assert dp_mesh.mesh_dim_names + dim_name = dp_mesh.mesh_dim_names[0] + + fully_shard(self, dp_mesh[dim_name], **kwargs) + class MoEMLP(MoEMLPBase): """ diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 6c074f0b3..70c079e84 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -453,8 +453,12 @@ def apply_fsdp( **fsdp_config, ) else: + block = cast(MoETransformerBlock, block) + block.feed_forward_moe.experts.mlp.fully_shard( + dp_mesh, reshard_after_forward=reshard_after_forward, **fsdp_config + ) fully_shard( - block.feed_forward_moe, # type: ignore + block.feed_forward_moe, reshard_after_forward=reshard_after_forward, **fsdp_config, ) From dc0e960eda53b8fe055629a18c13f20f8fd1caec Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:14:02 -0800 Subject: [PATCH 137/230] fix --- src/olmo_core/nn/transformer/model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 70c079e84..2e81bdf52 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -445,6 +445,12 @@ def apply_fsdp( # all-gathers, which can be expensive and non-overlapped reshard_after_forward = False if pp_enabled else True + if self.is_moe: + block = cast(MoETransformerBlock, block) + block.feed_forward_moe.experts.mlp.fully_shard( + dp_mesh, reshard_after_forward=reshard_after_forward, **fsdp_config + ) + if wrapping_strategy == TransformerDataParallelWrappingStrategy.fine_grained: if hasattr(block, "feed_forward"): fully_shard( @@ -453,12 +459,8 @@ def apply_fsdp( **fsdp_config, ) else: - block = cast(MoETransformerBlock, block) - block.feed_forward_moe.experts.mlp.fully_shard( - dp_mesh, reshard_after_forward=reshard_after_forward, **fsdp_config - ) fully_shard( - block.feed_forward_moe, + block.feed_forward_moe, # type: ignore reshard_after_forward=reshard_after_forward, **fsdp_config, ) From 5f47723239b3816a7143fc6867604f816fd1441a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:15:27 -0800 
Subject: [PATCH 138/230] fix --- src/olmo_core/nn/moe/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index c9ba7fa60..b264095ce 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -109,7 +109,7 @@ def fully_shard(self, dp_mesh: Optional[DeviceMesh] = None, **kwargs): assert dp_mesh.mesh_dim_names dim_name = dp_mesh.mesh_dim_names[0] - fully_shard(self, dp_mesh[dim_name], **kwargs) + fully_shard(self, mesh=dp_mesh[dim_name], **kwargs) class MoEMLP(MoEMLPBase): From 45ec6d8df2e85496d2ae75c3947b9de7fdcd93a7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:16:39 -0800 Subject: [PATCH 139/230] fix --- src/olmo_core/nn/moe/mlp.py | 10 +++++----- src/olmo_core/nn/transformer/model.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index b264095ce..c08f0fb53 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -100,16 +100,16 @@ def apply_ep( # replicate_dim_name = ep_mesh.mesh_dim_names[0] # replicate(self, device_mesh=ep_mesh[replicate_dim_name]) - def fully_shard(self, dp_mesh: Optional[DeviceMesh] = None, **kwargs): + def fully_shard(self, mesh: Optional[DeviceMesh] = None, **kwargs): from torch.distributed._composable.fsdp import fully_shard - if self.ep_pg is None or dp_mesh is None or dp_mesh.ndim != 2: + if self.ep_pg is None or mesh is None or mesh.ndim != 2: return - assert dp_mesh.mesh_dim_names - dim_name = dp_mesh.mesh_dim_names[0] + assert mesh.mesh_dim_names + dim_name = mesh.mesh_dim_names[0] - fully_shard(self, mesh=dp_mesh[dim_name], **kwargs) + fully_shard(self, mesh=mesh[dim_name], **kwargs) class MoEMLP(MoEMLPBase): diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 2e81bdf52..a50b77461 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -448,7 +448,7 @@ def apply_fsdp( if self.is_moe: block = cast(MoETransformerBlock, block) block.feed_forward_moe.experts.mlp.fully_shard( - dp_mesh, reshard_after_forward=reshard_after_forward, **fsdp_config + reshard_after_forward=reshard_after_forward, **fsdp_config ) if wrapping_strategy == TransformerDataParallelWrappingStrategy.fine_grained: From 0936393c66ded2ff110764acaeec146f03795942 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:53:10 -0800 Subject: [PATCH 140/230] clean up --- src/olmo_core/nn/moe/mlp.py | 30 +++++++++-------- src/olmo_core/nn/moe/moe.py | 19 +++++------ src/olmo_core/nn/moe/parallel_mlp.py | 32 ++++++------------- src/olmo_core/nn/transformer/block.py | 13 ++------ src/olmo_core/nn/transformer/model.py | 15 ++------- .../train/train_module/transformer.py | 2 +- 6 files changed, 41 insertions(+), 70 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index c08f0fb53..59f1f527b 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed import DeviceMesh -from torch.distributed.tensor import Placement, Replicate, Shard, distribute_tensor +from torch.distributed.tensor import Placement, Shard, distribute_tensor from ...distributed.utils import get_local_tensor from ...exceptions import OLMoConfigurationError @@ -46,6 +46,7 @@ def __init__( self.gradient_scale: Optional[float] = None self.num_local_experts = num_experts self.hidden_sharding_degree = 1 
+ self.ep_mesh: Optional[DeviceMesh] = None self.ep_pg: Optional[dist.ProcessGroup] = None def scale_grad(self, w: torch.Tensor) -> torch.Tensor: @@ -53,17 +54,10 @@ def scale_grad(self, w: torch.Tensor) -> torch.Tensor: return w return _scale_gradient(w, self.gradient_scale) - def apply_ep( - self, - ep_mesh: DeviceMesh, - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, - ): + def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. """ - from torch.distributed._composable.replicate import replicate - if ep_mesh.ndim > 2: raise RuntimeError("expert parallel mesh must be 1 or 2D") if not ep_mesh.mesh_dim_names: @@ -71,6 +65,7 @@ def apply_ep( shard_dim_name = ep_mesh.mesh_dim_names[-1] + self.ep_mesh = ep_mesh self.ep_pg = ep_mesh[shard_dim_name].get_group() num_shards = ep_mesh[shard_dim_name].size() @@ -100,15 +95,24 @@ def apply_ep( # replicate_dim_name = ep_mesh.mesh_dim_names[0] # replicate(self, device_mesh=ep_mesh[replicate_dim_name]) - def fully_shard(self, mesh: Optional[DeviceMesh] = None, **kwargs): + def prepare_experts_for_fsdp(self, *, mesh: Optional[DeviceMesh] = None, **kwargs): + """ + Should be called before wrapping this module, or a parent module, with FSDP2. + + If expert parallelism is enabled over the same mesh, this will shard the local experts + over the appropriate mesh dimension. Otherwise this is a no-op. + """ from torch.distributed._composable.fsdp import fully_shard - if self.ep_pg is None or mesh is None or mesh.ndim != 2: + if mesh is None or self.mesh is None or mesh != self.mesh: return - assert mesh.mesh_dim_names - dim_name = mesh.mesh_dim_names[0] + if mesh.ndim != 2: + raise RuntimeError("expected 2D mesh!") + if mesh.mesh_dim_names is None: + raise RuntimeError("mesh must have named dimensions!") + dim_name = mesh.mesh_dim_names[0] fully_shard(self, mesh=mesh[dim_name], **kwargs) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 8d6d974ee..8ba2a25b9 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -198,20 +198,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out - def apply_ep( - self, - ep_mesh: DeviceMesh, - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, - ): + def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): """ Apply expert parallelism. """ - self.experts.apply_ep( - ep_mesh, - compile_enabled=compile_enabled, - autograd_compile_enabled=autograd_compile_enabled, - ) + self.experts.apply_ep(ep_mesh, **kwargs) + + def prepare_experts_for_fsdp(self, **kwargs): + """ + Should be called before wrapping this module with FSDP2. + """ + self.experts.prepare_experts_for_fsdp(**kwargs) def apply_tp( self, diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 985bbaf7e..d84af3d6e 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -61,22 +61,19 @@ def ep_world_size(self) -> int: def ep_pg(self) -> Optional[dist.ProcessGroup]: return self.mlp.ep_pg - def apply_ep( - self, - ep_mesh: DeviceMesh, - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, - ): + def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): """ Apply expert parallelism. 
""" - self.mlp.apply_ep( - ep_mesh, - compile_enabled=compile_enabled, - autograd_compile_enabled=autograd_compile_enabled, - ) + self.mlp.apply_ep(ep_mesh, **kwargs) self._expert_parallel_enabled = True + def prepare_experts_for_fsdp(self, **kwargs): + """ + Should be called before wrapping this module with FSDP2. + """ + self.mlp.prepare_experts_for_fsdp(**kwargs) + def indices_and_bins( self, expert_indices: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -194,17 +191,8 @@ def warmup_cache(self, max_local_microbatch_size: int): device=get_default_device(), ) - def apply_ep( - self, - ep_mesh: DeviceMesh, - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, - ): - super().apply_ep( - ep_mesh, - compile_enabled=compile_enabled, - autograd_compile_enabled=autograd_compile_enabled, - ) + def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): + super().apply_ep(ep_mesh, **kwargs) if self.max_local_microbatch_size is not None: self.warmup_cache(self.max_local_microbatch_size) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 804fd8918..95777379e 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -397,17 +397,8 @@ def forward( ) return h + self.dropout(self.feed_forward_moe(self.feed_forward_norm(h))) - def apply_ep( - self, - ep_mesh: DeviceMesh, - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, - ): - self.feed_forward_moe.apply_ep( - ep_mesh, - compile_enabled=compile_enabled, - autograd_compile_enabled=autograd_compile_enabled, - ) + def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): + self.feed_forward_moe.apply_ep(ep_mesh, **kwargs) def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: return Shard(1) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index a50b77461..f270114ec 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -447,7 +447,7 @@ def apply_fsdp( if self.is_moe: block = cast(MoETransformerBlock, block) - block.feed_forward_moe.experts.mlp.fully_shard( + block.feed_forward_moe.prepare_experts_for_fsdp( reshard_after_forward=reshard_after_forward, **fsdp_config ) @@ -678,15 +678,6 @@ def forward( return self.lm_head(h) if self.lm_head is not None else h - def apply_ep( - self, - ep_mesh: DeviceMesh, - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, - ): + def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): for block in self.blocks.values(): - cast(MoETransformerBlock, block).apply_ep( - ep_mesh, - compile_enabled=compile_enabled, - autograd_compile_enabled=autograd_compile_enabled, - ) + cast(MoETransformerBlock, block).apply_ep(ep_mesh, **kwargs) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 228d2190a..dc94a89a9 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -318,7 +318,7 @@ def __init__( if not self.model.is_moe: raise OLMoConfigurationError("Expert parallelism is only valid for MoE models") ep_mesh = get_ep_mesh(self.world_mesh) - cast(MoETransformer, self.model).apply_ep(ep_mesh, compile_enabled=compile_model) + cast(MoETransformer, self.model).apply_ep(ep_mesh) log.info("Applied expert parallelism to the model") # Maybe apply activation checkpointing. 
From 4e48bb2f828ec5ebbfdb0300920d0a7dea4cef79 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:58:40 -0800 Subject: [PATCH 141/230] logging --- src/olmo_core/nn/moe/mlp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 59f1f527b..84944c4c5 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -1,3 +1,4 @@ +import logging import warnings from typing import Any, Callable, List, Optional @@ -14,6 +15,9 @@ __all__ = ["MoEMLP", "DroplessMoEMLP"] +log = logging.getLogger(__name__) + + class _ScaleGradient(torch.autograd.Function): @staticmethod @torch.amp.autocast_mode.custom_fwd(device_type="cuda") @@ -64,6 +68,7 @@ def apply_ep(self, ep_mesh: DeviceMesh): raise RuntimeError("expert parallel mesh must have named dimensions") shard_dim_name = ep_mesh.mesh_dim_names[-1] + log.info(f"Splitting experts over mesh dimension '{shard_dim_name}'...") self.ep_mesh = ep_mesh self.ep_pg = ep_mesh[shard_dim_name].get_group() @@ -113,6 +118,7 @@ def prepare_experts_for_fsdp(self, *, mesh: Optional[DeviceMesh] = None, **kwarg raise RuntimeError("mesh must have named dimensions!") dim_name = mesh.mesh_dim_names[0] + log.info(f"Sharding local experts over mesh dimension '{dim_name}'...") fully_shard(self, mesh=mesh[dim_name], **kwargs) From d63e37276c7108675c8c77836583185a2092ead5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 10:59:26 -0800 Subject: [PATCH 142/230] fix --- src/olmo_core/nn/moe/mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 84944c4c5..ded650ca1 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -50,7 +50,7 @@ def __init__( self.gradient_scale: Optional[float] = None self.num_local_experts = num_experts self.hidden_sharding_degree = 1 - self.ep_mesh: Optional[DeviceMesh] = None + self.mesh: Optional[DeviceMesh] = None self.ep_pg: Optional[dist.ProcessGroup] = None def scale_grad(self, w: torch.Tensor) -> torch.Tensor: @@ -70,7 +70,7 @@ def apply_ep(self, ep_mesh: DeviceMesh): shard_dim_name = ep_mesh.mesh_dim_names[-1] log.info(f"Splitting experts over mesh dimension '{shard_dim_name}'...") - self.ep_mesh = ep_mesh + self.mesh = ep_mesh self.ep_pg = ep_mesh[shard_dim_name].get_group() num_shards = ep_mesh[shard_dim_name].size() From 99a7f094250b0e548dc4ea00235fa0113094b963 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 11:12:35 -0800 Subject: [PATCH 143/230] clean up --- src/olmo_core/nn/moe/mlp.py | 44 +++++++++++++++++++--------- src/olmo_core/nn/moe/moe.py | 8 +---- src/olmo_core/nn/moe/parallel_mlp.py | 7 +++++ 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index ded650ca1..5a9c214cf 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -1,6 +1,6 @@ import logging import warnings -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Literal, Optional import torch import torch.distributed as dist @@ -61,18 +61,36 @@ def scale_grad(self, w: torch.Tensor) -> torch.Tensor: def apply_ep(self, ep_mesh: DeviceMesh): """ Apply expert parallelism. + + :param ep_mesh: A 2D device mesh. 
+ """ + self._shard_experts(ep_mesh, "ep") + + def apply_tp(self, tp_mesh: DeviceMesh): """ - if ep_mesh.ndim > 2: - raise RuntimeError("expert parallel mesh must be 1 or 2D") - if not ep_mesh.mesh_dim_names: + Apply expert parallelism. + + :param tp_mesh: A 1D device mesh. + """ + self._shard_experts(tp_mesh, "tp") + + def _shard_experts(self, mesh: DeviceMesh, flavor: Literal["tp", "ep"]): + if flavor == "ep": + if mesh.ndim != 2: + raise RuntimeError("expert parallel mesh must be 2 dimensional") + elif flavor == "tp": + if mesh.ndim != 1: + raise RuntimeError("tensor parallel mesh must be 1 dimensional") + + if not mesh.mesh_dim_names: raise RuntimeError("expert parallel mesh must have named dimensions") - shard_dim_name = ep_mesh.mesh_dim_names[-1] + shard_dim_name = mesh.mesh_dim_names[-1] log.info(f"Splitting experts over mesh dimension '{shard_dim_name}'...") - self.mesh = ep_mesh - self.ep_pg = ep_mesh[shard_dim_name].get_group() - num_shards = ep_mesh[shard_dim_name].size() + self.mesh = mesh + self.ep_pg = mesh[shard_dim_name].get_group() + num_shards = mesh[shard_dim_name].size() if self.num_experts % num_shards != 0: raise OLMoConfigurationError( @@ -83,12 +101,10 @@ def apply_ep(self, ep_mesh: DeviceMesh): self.gradient_scale = 1.0 / num_shards placements: List[Placement] = [Shard(0)] - # if ep_mesh.ndim > 1: - # placements.insert(0, Replicate()) - self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh[shard_dim_name], placements))) # type: ignore - self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh[shard_dim_name], placements))) # type: ignore - self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh[shard_dim_name], placements))) # type: ignore + self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, mesh[shard_dim_name], placements))) # type: ignore + self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, mesh[shard_dim_name], placements))) # type: ignore + self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, mesh[shard_dim_name], placements))) # type: ignore # if ep_mesh.ndim > 1: # if compile_enabled: @@ -113,7 +129,7 @@ def prepare_experts_for_fsdp(self, *, mesh: Optional[DeviceMesh] = None, **kwarg return if mesh.ndim != 2: - raise RuntimeError("expected 2D mesh!") + raise RuntimeError("expected a 2D mesh!") if mesh.mesh_dim_names is None: raise RuntimeError("mesh must have named dimensions!") diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 8ba2a25b9..5bca214b6 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -215,19 +215,13 @@ def apply_tp( tp_mesh: DeviceMesh, output_layouts: Optional[Placement] = None, use_local_output: bool = True, - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, ): parallelize_module( self.router, device_mesh=tp_mesh, parallelize_plan=SequenceParallel(use_local_output=True), ) - self.experts.apply_ep( - tp_mesh, - compile_enabled=compile_enabled, - autograd_compile_enabled=autograd_compile_enabled, - ) + self.experts.apply_tp(tp_mesh) parallelize_module( self, device_mesh=tp_mesh, diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index d84af3d6e..30746d7b8 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -68,6 +68,13 @@ def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): self.mlp.apply_ep(ep_mesh, **kwargs) self._expert_parallel_enabled = True + def apply_tp(self, 
tp_mesh: DeviceMesh, **kwargs): + """ + Apply tensor parallelism. + """ + self.mlp.apply_tp(tp_mesh, **kwargs) + self._expert_parallel_enabled = True + def prepare_experts_for_fsdp(self, **kwargs): """ Should be called before wrapping this module with FSDP2. From 2bcc0796b9b04caa8b7dccb5496559d8bb819b25 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 11:18:45 -0800 Subject: [PATCH 144/230] try with replicate --- src/olmo_core/nn/moe/mlp.py | 40 ++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 5a9c214cf..77aae0169 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -81,6 +81,8 @@ def _shard_experts(self, mesh: DeviceMesh, flavor: Literal["tp", "ep"]): elif flavor == "tp": if mesh.ndim != 1: raise RuntimeError("tensor parallel mesh must be 1 dimensional") + else: + raise ValueError(flavor) if not mesh.mesh_dim_names: raise RuntimeError("expert parallel mesh must have named dimensions") @@ -101,22 +103,19 @@ def _shard_experts(self, mesh: DeviceMesh, flavor: Literal["tp", "ep"]): self.gradient_scale = 1.0 / num_shards placements: List[Placement] = [Shard(0)] - self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, mesh[shard_dim_name], placements))) # type: ignore self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, mesh[shard_dim_name], placements))) # type: ignore self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, mesh[shard_dim_name], placements))) # type: ignore - # if ep_mesh.ndim > 1: - # if compile_enabled: - # if autograd_compile_enabled: - # torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward" # type: ignore - # else: - # torch._dynamo.config.optimize_ddp = "ddp_optimizer" # type: ignore - - # replicate_dim_name = ep_mesh.mesh_dim_names[0] - # replicate(self, device_mesh=ep_mesh[replicate_dim_name]) - - def prepare_experts_for_fsdp(self, *, mesh: Optional[DeviceMesh] = None, **kwargs): + def prepare_experts_for_fsdp( + self, + *, + mesh: Optional[DeviceMesh] = None, + strategy: Literal["replicate", "shard"] = "replicate", + compile_enabled: bool = False, + autograd_compile_enabled: bool = False, + **kwargs, + ): """ Should be called before wrapping this module, or a parent module, with FSDP2. @@ -124,6 +123,7 @@ def prepare_experts_for_fsdp(self, *, mesh: Optional[DeviceMesh] = None, **kwarg over the appropriate mesh dimension. Otherwise this is a no-op. 
""" from torch.distributed._composable.fsdp import fully_shard + from torch.distributed._composable.replicate import replicate if mesh is None or self.mesh is None or mesh != self.mesh: return @@ -134,8 +134,20 @@ def prepare_experts_for_fsdp(self, *, mesh: Optional[DeviceMesh] = None, **kwarg raise RuntimeError("mesh must have named dimensions!") dim_name = mesh.mesh_dim_names[0] - log.info(f"Sharding local experts over mesh dimension '{dim_name}'...") - fully_shard(self, mesh=mesh[dim_name], **kwargs) + if strategy == "shard": + log.info(f"Sharding local experts over mesh dimension '{dim_name}'...") + fully_shard(self, mesh=mesh[dim_name], **kwargs) + elif strategy == "replicate": + if compile_enabled: + if autograd_compile_enabled: + torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward" # type: ignore + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" # type: ignore + + log.info(f"Replicating local experts over mesh dimension '{dim_name}'...") + replicate(self, device_mesh=mesh[dim_name]) + else: + raise ValueError(strategy) class MoEMLP(MoEMLPBase): From b2c39b8cc5333d1043f94524c597c72e243ce200 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 11:22:39 -0800 Subject: [PATCH 145/230] clean up --- src/olmo_core/nn/moe/mlp.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 77aae0169..2b3f88779 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -112,8 +112,6 @@ def prepare_experts_for_fsdp( *, mesh: Optional[DeviceMesh] = None, strategy: Literal["replicate", "shard"] = "replicate", - compile_enabled: bool = False, - autograd_compile_enabled: bool = False, **kwargs, ): """ @@ -138,12 +136,7 @@ def prepare_experts_for_fsdp( log.info(f"Sharding local experts over mesh dimension '{dim_name}'...") fully_shard(self, mesh=mesh[dim_name], **kwargs) elif strategy == "replicate": - if compile_enabled: - if autograd_compile_enabled: - torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward" # type: ignore - else: - torch._dynamo.config.optimize_ddp = "ddp_optimizer" # type: ignore - + # TODO: this doesn't work yet. 
log.info(f"Replicating local experts over mesh dimension '{dim_name}'...") replicate(self, device_mesh=mesh[dim_name]) else: From d988ff212964f7aa5639a52031c7ab5cea6f41ea Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 11:23:58 -0800 Subject: [PATCH 146/230] back to sharding --- src/olmo_core/nn/moe/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 2b3f88779..33038b7d2 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -111,7 +111,7 @@ def prepare_experts_for_fsdp( self, *, mesh: Optional[DeviceMesh] = None, - strategy: Literal["replicate", "shard"] = "replicate", + strategy: Literal["replicate", "shard"] = "shard", **kwargs, ): """ From b441c45ed956af2408aa99096abdc03d8ab5e3c0 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 11:27:06 -0800 Subject: [PATCH 147/230] clean up --- src/olmo_core/nn/moe/parallel_mlp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 30746d7b8..111abad90 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -203,6 +203,11 @@ def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): if self.max_local_microbatch_size is not None: self.warmup_cache(self.max_local_microbatch_size) + def apply_tp(self, tp_mesh: DeviceMesh, **kwargs): + super().apply_tp(tp_mesh, **kwargs) + if self.max_local_microbatch_size is not None: + self.warmup_cache(self.max_local_microbatch_size) + def expert_capacity(self, local_batch_size: int) -> int: # NOTE: need to ensure this is the same across the process group. # If local batch sizes are different then these will be different, and `parallel_forward_once` @@ -461,6 +466,7 @@ def forward_once( return out, batch_size_per_expert + @torch._dynamo.disable() def parallel_forward_once( self, x: torch.Tensor, From 18a007353323e87049c5ee09b2ca14ce091f4668 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 11:28:14 -0800 Subject: [PATCH 148/230] try dropless --- src/olmo_core/nn/transformer/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index fb797e0a8..eeb53ebdf 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -359,7 +359,7 @@ def smallmoe(cls, vocab_size: int, **kwargs) -> "TransformerConfig": rope_theta=kwargs.pop("rope_theta", 500_000), layer_norm_eps=1e-6, feed_forward_moe=MoEConfig( - name=MoEType.default, + name=MoEType.dropless, num_experts=32, hidden_size=int(0.5 * d_model), router=MoERouterConfig(top_k=4, bias=False), From 7d7f7b5805d8493bdeeef9d1ed3fc11071137a2d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 11:34:17 -0800 Subject: [PATCH 149/230] revert change to dropless --- src/olmo_core/nn/transformer/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index eeb53ebdf..fb797e0a8 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -359,7 +359,7 @@ def smallmoe(cls, vocab_size: int, **kwargs) -> "TransformerConfig": rope_theta=kwargs.pop("rope_theta", 500_000), layer_norm_eps=1e-6, feed_forward_moe=MoEConfig( - name=MoEType.dropless, + name=MoEType.default, num_experts=32, hidden_size=int(0.5 * d_model), router=MoERouterConfig(top_k=4, 
bias=False), From 281785d6477594b6c8dfe48ad45e0062ad124690 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 12:18:14 -0800 Subject: [PATCH 150/230] fixes --- src/olmo_core/nn/transformer/block.py | 1 + src/olmo_core/nn/transformer/model.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 95777379e..30f80069f 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -400,6 +400,7 @@ def forward( def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): self.feed_forward_moe.apply_ep(ep_mesh, **kwargs) + @property def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: return Shard(1) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index f270114ec..b46217472 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -291,23 +291,27 @@ def apply_tp( parallelize_module, ) - parallelize_module( - module=self, - device_mesh=tp_mesh, - parallelize_plan={ - "embeddings": RowwiseParallel( + if self.embeddings is not None: + parallelize_module( + self.embeddings, + device_mesh=tp_mesh, + parallelize_plan=RowwiseParallel( input_layouts=Replicate(), use_local_output=False, ), - "lm_head": PrepareModuleInput( + ) + + if self.lm_head is not None: + parallelize_module( + self.lm_head, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( # block output layouts are same as block input layouts input_layouts=cast(TransformerBlockBase, self.blocks["0"]).tp_input_layouts, desired_input_layouts=self.lm_head.tp_input_layouts, ), - }, - ) - - self.lm_head.apply_tp(tp_mesh, loss_parallel=loss_parallel) + ) + self.lm_head.apply_tp(tp_mesh, loss_parallel=loss_parallel) # Apply tensor + sequence parallelism to every transformer block. 
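Note on the embedding change in PATCH 150 above: RowwiseParallel on an nn.Embedding shards the vocabulary dimension (dim 0) across the TP mesh, each rank embeds the token ids it owns, and the partial results are reduced, with use_local_output=False keeping the result as a DTensor for the next module. A rough standalone sketch of that pattern (hypothetical sizes, assumes a torchrun launch with 2 CUDA ranks; not code from this repo):

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

    tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

    emb = nn.Embedding(1024, 64, device="cuda")  # vocab is sharded 512 / 512
    parallelize_module(
        emb,
        device_mesh=tp_mesh,
        parallelize_plan=RowwiseParallel(input_layouts=Replicate(), use_local_output=False),
    )

    torch.manual_seed(0)  # same ids on every rank, matching the Replicate() input layout
    ids = torch.randint(0, 1024, (2, 16), device="cuda")
    print(emb(ids).shape)  # (2, 16, 64); a DTensor replicated across the TP group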
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel From 4c6d48e6f9029787634f0e1952f5baf2894cc9c8 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 12:28:54 -0800 Subject: [PATCH 151/230] try again --- src/olmo_core/nn/moe/moe.py | 6 +----- src/olmo_core/nn/moe/router.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 5bca214b6..6ae4b60a3 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -216,11 +216,7 @@ def apply_tp( output_layouts: Optional[Placement] = None, use_local_output: bool = True, ): - parallelize_module( - self.router, - device_mesh=tp_mesh, - parallelize_plan=SequenceParallel(use_local_output=True), - ) + self.router.apply_tp(tp_mesh) self.experts.apply_tp(tp_mesh) parallelize_module( self, diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 3b6ca0949..d08f5e2f9 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -4,9 +4,12 @@ import torch import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.tensor.parallel import parallelize_module -from ...config import Config, DType, StrEnum -from ...exceptions import OLMoConfigurationError +from olmo_core.config import Config, DType, StrEnum +from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel +from olmo_core.exceptions import OLMoConfigurationError __all__ = ["MoERouter", "MoELinearRouter", "MoERouterConfig", "MoERouterType"] @@ -204,6 +207,10 @@ def forward( return logits, scores, expert_weights, expert_indices + @abstractmethod + def apply_tp(self, tp_mesh: DeviceMesh): + raise NotImplementedError + class MoELinearRouter(MoERouter): """ @@ -225,3 +232,10 @@ def __init__( def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: return self.w_score(x.view(-1, self.d_model)) + + def apply_tp(self, tp_mesh: DeviceMesh): + parallelize_module( + self.w_score, + device_mesh=tp_mesh, + parallelize_plan=SequenceParallel(use_local_output=True), + ) From 4a0c0e27f9d125953fe6e76ab3c00bb68cc830a1 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:05:37 -0800 Subject: [PATCH 152/230] try this --- src/olmo_core/nn/moe/mlp.py | 3 ++- src/olmo_core/nn/moe/moe.py | 14 +++++++++++--- src/olmo_core/nn/moe/router.py | 13 ++++++++++--- src/olmo_core/nn/moe/shared_mlp.py | 27 +++++++++++++++++++++++++-- src/olmo_core/nn/transformer/block.py | 4 +++- 5 files changed, 51 insertions(+), 10 deletions(-) diff --git a/src/olmo_core/nn/moe/mlp.py b/src/olmo_core/nn/moe/mlp.py index 33038b7d2..f8544c877 100644 --- a/src/olmo_core/nn/moe/mlp.py +++ b/src/olmo_core/nn/moe/mlp.py @@ -66,12 +66,13 @@ def apply_ep(self, ep_mesh: DeviceMesh): """ self._shard_experts(ep_mesh, "ep") - def apply_tp(self, tp_mesh: DeviceMesh): + def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): """ Apply expert parallelism. :param tp_mesh: A 1D device mesh. 
""" + del float8_enabled # TODO self._shard_experts(tp_mesh, "tp") def _shard_experts(self, mesh: DeviceMesh, flavor: Literal["tp", "ep"]): diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 6ae4b60a3..65c8429c3 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -9,7 +9,6 @@ from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module from olmo_core.config import Config, DType, StrEnum -from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel from olmo_core.exceptions import OLMoConfigurationError from ..buffer_cache import BufferCache @@ -215,9 +214,18 @@ def apply_tp( tp_mesh: DeviceMesh, output_layouts: Optional[Placement] = None, use_local_output: bool = True, + float8_enabled: bool = False, ): - self.router.apply_tp(tp_mesh) - self.experts.apply_tp(tp_mesh) + # Sequence parallel + self.router.apply_tp(tp_mesh, float8_enabled=float8_enabled) + + # Expert parallel + self.experts.apply_tp(tp_mesh, float8_enabled=float8_enabled) + + # Sequence parallel + if self.shared_experts is not None: + self.shared_experts.apply_tp(tp_mesh, float8_enabled=float8_enabled) + parallelize_module( self, device_mesh=tp_mesh, diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index d08f5e2f9..6c965d1f1 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn from torch.distributed import DeviceMesh -from torch.distributed.tensor.parallel import parallelize_module +from torch.distributed.tensor import Shard +from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module from olmo_core.config import Config, DType, StrEnum from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel @@ -208,7 +209,7 @@ def forward( return logits, scores, expert_weights, expert_indices @abstractmethod - def apply_tp(self, tp_mesh: DeviceMesh): + def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): raise NotImplementedError @@ -233,9 +234,15 @@ def __init__( def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: return self.w_score(x.view(-1, self.d_model)) - def apply_tp(self, tp_mesh: DeviceMesh): + def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): + del float8_enabled parallelize_module( self.w_score, device_mesh=tp_mesh, parallelize_plan=SequenceParallel(use_local_output=True), ) + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput(desired_input_layouts=(Shard(1),)), + ) diff --git a/src/olmo_core/nn/moe/shared_mlp.py b/src/olmo_core/nn/moe/shared_mlp.py index 9cff98ef9..dec25386e 100644 --- a/src/olmo_core/nn/moe/shared_mlp.py +++ b/src/olmo_core/nn/moe/shared_mlp.py @@ -3,9 +3,14 @@ import torch import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.tensor import Shard +from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module + +from olmo_core.config import Config, DType, StrEnum +from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel +from olmo_core.exceptions import OLMoConfigurationError -from ...config import Config, DType, StrEnum -from ...exceptions import OLMoConfigurationError from ..feed_forward import FeedForward __all__ = ["SharedMLP", "SharedMLPConfig", "SharedMLPType"] @@ -120,3 +125,21 @@ def forward(self, x: torch.Tensor, experts_out: torch.Tensor, top_k: int) -> tor else: shared_out.add_(experts_out) return 
shared_out + + def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): + # Alternatively could do colwise->rowwise->colwise parallelism + del float8_enabled + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan={ + "mlp.w1": SequenceParallel(), + "mlp.w2": SequenceParallel(), + "mlp.w3": SequenceParallel(), + "mlp": PrepareModuleOutput( + output_layouts=(Shard(1),), + desired_output_layouts=(Shard(1),), + use_local_output=True, + ), + }, + ) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 30f80069f..b842eef42 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -424,7 +424,9 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): ) self.attention.apply_tp(tp_mesh, output_layouts=Shard(1), float8_enabled=float8_enabled) - self.feed_forward_moe.apply_tp(tp_mesh, output_layouts=Shard(1), use_local_output=False) + self.feed_forward_moe.apply_tp( + tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled + ) class MoEReorderedNormTransformerBlock(MoETransformerBlock): From c66bede0d3c771def6064b7915e7ce95a60c5e18 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:10:25 -0800 Subject: [PATCH 153/230] idk --- src/olmo_core/nn/moe/router.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 6c965d1f1..8a917c84f 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -237,12 +237,12 @@ def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): del float8_enabled parallelize_module( - self.w_score, + self, device_mesh=tp_mesh, - parallelize_plan=SequenceParallel(use_local_output=True), + parallelize_plan=PrepareModuleInput(desired_input_layouts=(Shard(1),)), ) parallelize_module( - self, + self.w_score, device_mesh=tp_mesh, - parallelize_plan=PrepareModuleInput(desired_input_layouts=(Shard(1),)), + parallelize_plan=SequenceParallel(use_local_output=True), ) From a2b8a1bc7ef0c411b61c5a77a66e6d0885d9acd8 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:15:24 -0800 Subject: [PATCH 154/230] fix? 
--- src/olmo_core/nn/transformer/block.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index b842eef42..98d30765a 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -228,8 +228,12 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): parallelize_plan=plan, ) - self.attention.apply_tp(tp_mesh, output_layouts=Shard(1), float8_enabled=float8_enabled) - self.feed_forward.apply_tp(tp_mesh, output_layouts=Shard(1), float8_enabled=float8_enabled) + self.attention.apply_tp( + tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled + ) + self.feed_forward.apply_tp( + tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled + ) class ReorderedNormTransformerBlock(TransformerBlock): From ea2bb5ca577a6504015518efbff2775668356db7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:21:26 -0800 Subject: [PATCH 155/230] idk --- src/olmo_core/nn/transformer/block.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 98d30765a..55e1cd868 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -418,6 +418,10 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): desired_input_layouts=(Replicate(),), ), "feed_forward_norm": SequenceParallel(), + "feed_forward_moe": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Shard(1),), + ), } if isinstance(self.dropout, nn.Dropout): plan["dropout"] = SequenceParallel() @@ -427,7 +431,9 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): parallelize_plan=plan, ) - self.attention.apply_tp(tp_mesh, output_layouts=Shard(1), float8_enabled=float8_enabled) + self.attention.apply_tp( + tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled + ) self.feed_forward_moe.apply_tp( tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled ) From f8d594075511dfa17d1b685598875bde58a2093c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:24:58 -0800 Subject: [PATCH 156/230] debugging --- src/olmo_core/nn/moe/moe.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 65c8429c3..5de74f66c 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -1,3 +1,4 @@ +import logging from abc import abstractmethod from dataclasses import dataclass, field from typing import Dict, List, Optional, Union @@ -21,6 +22,9 @@ __all__ = ["MoEBase", "MoE", "DroplessMoE", "MoEConfig", "MoEType"] +log = logging.getLogger(__name__) + + class MoEType(StrEnum): """ An enumeration of the different MoE implementations. @@ -181,8 +185,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. 
""" - expert_logits, expert_scores, expert_weights, exper_indices = self.router(x) - out, batch_size_per_expert = self.experts(x, expert_weights, exper_indices) + log.info(f"{x=}") + expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) + log.info(f"{expert_logits=}, {expert_scores=}, {expert_weights=}, {expert_indices=}") + out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) if self.shared_experts is not None: out = self.shared_experts(x, out, self.router.top_k) From 406d5ed1b63ba95b4388690c35ff053c20de8491 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:29:44 -0800 Subject: [PATCH 157/230] make input local --- src/olmo_core/nn/moe/moe.py | 4 ++-- src/olmo_core/nn/moe/router.py | 5 ++++- src/olmo_core/nn/transformer/block.py | 5 +++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 5de74f66c..3fc3ea85d 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -185,9 +185,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. """ - log.info(f"{x=}") expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) - log.info(f"{expert_logits=}, {expert_scores=}, {expert_weights=}, {expert_indices=}") out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) if self.shared_experts is not None: out = self.shared_experts(x, out, self.router.top_k) @@ -222,6 +220,8 @@ def apply_tp( use_local_output: bool = True, float8_enabled: bool = False, ): + # Input layouts assumed to be (Shard(1),) + # Sequence parallel self.router.apply_tp(tp_mesh, float8_enabled=float8_enabled) diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 8a917c84f..a83cd59d3 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -239,7 +239,10 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): parallelize_module( self, device_mesh=tp_mesh, - parallelize_plan=PrepareModuleInput(desired_input_layouts=(Shard(1),)), + parallelize_plan=PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Shard(1),), + ), ) parallelize_module( self.w_score, diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 55e1cd868..39b897346 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -7,7 +7,7 @@ import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.tensor import Placement, Replicate, Shard -from torch.distributed.tensor.parallel import parallelize_module +from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module from olmo_core.config import Config, StrEnum from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel @@ -418,9 +418,10 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): desired_input_layouts=(Replicate(),), ), "feed_forward_norm": SequenceParallel(), - "feed_forward_moe": prepare_module_input( + "feed_forward_moe": PrepareModuleInput( input_layouts=(Shard(1),), desired_input_layouts=(Shard(1),), + use_local_output=True, ), } if isinstance(self.dropout, nn.Dropout): From 5928f0e5c2c1e5dfc29ec2c22532c81eb89f5844 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:31:07 -0800 Subject: [PATCH 158/230] debug --- src/olmo_core/nn/moe/moe.py | 1 + 1 
file changed, 1 insertion(+) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 3fc3ea85d..82950dff8 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -185,6 +185,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. """ + log.info(f"{x=}") expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) if self.shared_experts is not None: From 3c65c0c79735a5366a70e3d52afb74009e00b8ec Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:34:54 -0800 Subject: [PATCH 159/230] debug --- src/olmo_core/nn/moe/moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 82950dff8..56faa8029 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -185,9 +185,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. """ - log.info(f"{x=}") expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) + log.info(f"{expert_indices=}") + out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) + if self.shared_experts is not None: out = self.shared_experts(x, out, self.router.top_k) From 0c4b132efec1058455e0276381a0943528f9c740 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:36:18 -0800 Subject: [PATCH 160/230] debug --- src/olmo_core/nn/moe/moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 56faa8029..e9fe8e58a 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -186,8 +186,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: router Z-loss. """ expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) - log.info(f"{expert_indices=}") + log.info(f"{x=}") + log.info(f"{expert_weights=}") + log.info(f"{expert_indices=}") out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) if self.shared_experts is not None: From a829d734e5e17a19cba8d60f8a8cd485686451da Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:46:21 -0800 Subject: [PATCH 161/230] try this --- src/olmo_core/nn/moe/parallel_mlp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 111abad90..a301d67a3 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -182,6 +182,7 @@ def __init__( ): super().__init__(mlp=mlp, top_k=top_k, cache=cache) self.capacity_factor = capacity_factor + self.tp_degree: int = 1 self.max_local_microbatch_size = max_local_microbatch_size if self.max_local_microbatch_size is not None: self.warmup_cache(self.max_local_microbatch_size) @@ -205,6 +206,7 @@ def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): def apply_tp(self, tp_mesh: DeviceMesh, **kwargs): super().apply_tp(tp_mesh, **kwargs) + self.tp_degree = tp_mesh.size() if self.max_local_microbatch_size is not None: self.warmup_cache(self.max_local_microbatch_size) @@ -214,13 +216,14 @@ def expert_capacity(self, local_batch_size: int) -> int: # will break. This shouldn't be a problem with our trainer, but would be an issue for inference. 
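Note on the tp_degree adjustment in PATCH 161: with tensor parallelism the block hands the MoE a sequence-sharded (Shard(1)) local input, per PATCH 157, so each TP rank routes only 1/tp_degree of the configured per-rank tokens, and the warmed-up capacity has to be computed from max_local_microbatch_size // tp_degree or the cached bins would be sized for a batch that never arrives. A rough arithmetic check with made-up numbers:

    # Hypothetical settings, not taken from the patches.
    max_local_microbatch_size = 4096   # tokens per data-parallel rank
    tp_degree, top_k, num_experts = 2, 4, 32
    capacity_factor, ep_world_size = 1.25, 8

    tokens_seen_by_moe = max_local_microbatch_size // tp_degree             # 2048
    local_inputs_per_expert = top_k * tokens_seen_by_moe / num_experts      # 256.0
    expert_capacity = ep_world_size * int(capacity_factor * local_inputs_per_expert)  # 2560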
# To avoid that you could set `self.max_local_microbatch_size` up-front. if self.max_local_microbatch_size is not None: - if local_batch_size > self.max_local_microbatch_size: + max_local_microbatch_size = self.max_local_microbatch_size // self.tp_degree + if local_batch_size > max_local_microbatch_size: raise RuntimeError( f"Local batch size ({local_batch_size:,d}) bigger than " - f"configured max local batch size ({self.max_local_microbatch_size:,d})" + f"configured max local batch size ({max_local_microbatch_size:,d})" ) else: - local_batch_size = self.max_local_microbatch_size + local_batch_size = max_local_microbatch_size local_inputs_per_expert = self.top_k * local_batch_size / self.num_experts return self.ep_world_size * int(self.capacity_factor * local_inputs_per_expert) From f247d89ce95c6038f97e8620889ab7463e1081a9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:48:02 -0800 Subject: [PATCH 162/230] debug --- src/olmo_core/nn/moe/parallel_mlp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index a301d67a3..c19bfcc79 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -206,7 +206,9 @@ def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): def apply_tp(self, tp_mesh: DeviceMesh, **kwargs): super().apply_tp(tp_mesh, **kwargs) + print(tp_mesh) self.tp_degree = tp_mesh.size() + print(self.tp_degree) if self.max_local_microbatch_size is not None: self.warmup_cache(self.max_local_microbatch_size) From 78722ca1d16594ae3d0d7f9aafd7e6cff1e411f0 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:50:41 -0800 Subject: [PATCH 163/230] try again --- src/olmo_core/nn/moe/parallel_mlp.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index c19bfcc79..445c55373 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -206,9 +206,7 @@ def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): def apply_tp(self, tp_mesh: DeviceMesh, **kwargs): super().apply_tp(tp_mesh, **kwargs) - print(tp_mesh) self.tp_degree = tp_mesh.size() - print(self.tp_degree) if self.max_local_microbatch_size is not None: self.warmup_cache(self.max_local_microbatch_size) @@ -338,8 +336,7 @@ def parallel_forward_once( # shape: (N, d_model) x = x.view(-1, x.shape[-1]) - num_items, _ = expert_weights.shape - expert_capacity = self.expert_capacity(num_items) + expert_capacity = self.expert_capacity(x.shape[0]) local_expert_capacity = expert_capacity // self.ep_world_size # shape: (batch_size * top_k,) From 06d88e33f64a04c7cf6a978524ebcfb86fa8ba3b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:51:45 -0800 Subject: [PATCH 164/230] assert --- src/olmo_core/nn/moe/parallel_mlp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 445c55373..a9c62bd3c 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -211,6 +211,7 @@ def apply_tp(self, tp_mesh: DeviceMesh, **kwargs): self.warmup_cache(self.max_local_microbatch_size) def expert_capacity(self, local_batch_size: int) -> int: + assert isinstance(local_batch_size, int) # NOTE: need to ensure this is the same across the process group. # If local batch sizes are different then these will be different, and `parallel_forward_once` # will break. 
This shouldn't be a problem with our trainer, but would be an issue for inference. From 9ceb128a6a58ac887a2e7827d62adf527e854932 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:56:55 -0800 Subject: [PATCH 165/230] fix that --- src/olmo_core/nn/moe/moe.py | 3 --- src/olmo_core/nn/moe/parallel_mlp.py | 7 +++---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index e9fe8e58a..84d93f8fa 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -187,9 +187,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) - log.info(f"{x=}") - log.info(f"{expert_weights=}") - log.info(f"{expert_indices=}") out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) if self.shared_experts is not None: diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index a9c62bd3c..991a2ca31 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -191,7 +191,7 @@ def warmup_cache(self, max_local_microbatch_size: int): self.max_local_microbatch_size = max_local_microbatch_size # TODO: call `_get_parallel_indices_and_bins()` up-front to warm the cache so # torch.compile() doesn't try to trace that. - expert_capacity = self.expert_capacity(self.max_local_microbatch_size) + expert_capacity = self.expert_capacity(self.max_local_microbatch_size // self.tp_degree) local_expert_capacity = expert_capacity // self.ep_world_size self._get_parallel_indices_and_bins( expert_capacity=expert_capacity, @@ -211,7 +211,6 @@ def apply_tp(self, tp_mesh: DeviceMesh, **kwargs): self.warmup_cache(self.max_local_microbatch_size) def expert_capacity(self, local_batch_size: int) -> int: - assert isinstance(local_batch_size, int) # NOTE: need to ensure this is the same across the process group. # If local batch sizes are different then these will be different, and `parallel_forward_once` # will break. This shouldn't be a problem with our trainer, but would be an issue for inference. @@ -220,8 +219,8 @@ def expert_capacity(self, local_batch_size: int) -> int: max_local_microbatch_size = self.max_local_microbatch_size // self.tp_degree if local_batch_size > max_local_microbatch_size: raise RuntimeError( - f"Local batch size ({local_batch_size:,d}) bigger than " - f"configured max local batch size ({max_local_microbatch_size:,d})" + f"Local batch size ({local_batch_size:d}) bigger than " + f"configured max local batch size ({max_local_microbatch_size:d})" ) else: local_batch_size = max_local_microbatch_size From cead1750e16465b655f2a5dad532efb635171a3d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 13:58:33 -0800 Subject: [PATCH 166/230] debug --- src/olmo_core/nn/moe/moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 84d93f8fa..8e619a104 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -185,8 +185,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. 
""" + print(f"{x.shape=}") expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) + print(f"{expert_indices.shape=}") out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) if self.shared_experts is not None: From 794bc032420b9156e71ba24e29945f2c41603729 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 14:02:59 -0800 Subject: [PATCH 167/230] try this --- src/olmo_core/nn/moe/router.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index a83cd59d3..8a59680a1 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -239,6 +239,15 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): parallelize_module( self, device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Shard(1),), + use_local_output=True, + ), + ) + parallelize_module( + self.w_score, + device_mesh=tp_mesh, parallelize_plan=PrepareModuleInput( input_layouts=(Shard(1),), desired_input_layouts=(Shard(1),), From fb415b6c2144b27767341e92feb26067ac9cec3d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 14:07:45 -0800 Subject: [PATCH 168/230] maybe fix --- src/olmo_core/nn/moe/router.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/moe/router.py b/src/olmo_core/nn/moe/router.py index 8a59680a1..e9f07a774 100644 --- a/src/olmo_core/nn/moe/router.py +++ b/src/olmo_core/nn/moe/router.py @@ -185,7 +185,7 @@ def forward( x = self.jitter(x) # shape: (batch_size * seq_len, num_experts) - logits = self.get_expert_logits(x.view(-1, self.d_model)) + logits = self.get_expert_logits(x).view(-1, self.num_experts) # shape: (batch_size * seq_len, num_experts) scores = logits.softmax(dim=-1) @@ -232,7 +232,7 @@ def __init__( ) def get_expert_logits(self, x: torch.Tensor) -> torch.Tensor: - return self.w_score(x.view(-1, self.d_model)) + return self.w_score(x) def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): del float8_enabled From 76c322a544ecdf4d4166f88436090c908992bd2b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 14:09:55 -0800 Subject: [PATCH 169/230] remove inplace op --- src/olmo_core/nn/moe/shared_mlp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/nn/moe/shared_mlp.py b/src/olmo_core/nn/moe/shared_mlp.py index dec25386e..4e46a757d 100644 --- a/src/olmo_core/nn/moe/shared_mlp.py +++ b/src/olmo_core/nn/moe/shared_mlp.py @@ -120,11 +120,10 @@ def forward(self, x: torch.Tensor, experts_out: torch.Tensor, top_k: int) -> tor if self.weighted_sum: # Weighted by number of experts used n_active_experts = top_k + 1 - shared_out.div_(n_active_experts) - shared_out.add_(experts_out, alpha=top_k / n_active_experts) + shared_out = shared_out / n_active_experts + return shared_out.add(experts_out, alpha=top_k / n_active_experts) else: - shared_out.add_(experts_out) - return shared_out + return shared_out + experts_out def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): # Alternatively could do colwise->rowwise->colwise parallelism From d9d9d3325042319be341f4c51f81dc7fdc5d686f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 14:10:44 -0800 Subject: [PATCH 170/230] clean up --- src/olmo_core/nn/moe/moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 8e619a104..84d93f8fa 100644 --- 
a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -185,10 +185,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. """ - print(f"{x.shape=}") expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) - print(f"{expert_indices.shape=}") out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) if self.shared_experts is not None: From 0ea34ed2582a7c6a6f7676ebcf5c4225387e4684 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:18:31 -0800 Subject: [PATCH 171/230] clean up tensor parallel --- src/olmo_core/nn/attention.py | 30 +++++--- src/olmo_core/nn/feed_forward.py | 25 +++++-- src/olmo_core/nn/lm_head.py | 54 +++++++++----- src/olmo_core/nn/moe/moe.py | 20 ++++-- src/olmo_core/nn/transformer/block.py | 100 ++++++++++---------------- src/olmo_core/nn/transformer/model.py | 31 +++----- 6 files changed, 138 insertions(+), 122 deletions(-) diff --git a/src/olmo_core/nn/attention.py b/src/olmo_core/nn/attention.py index 67210e01b..140efa4ea 100644 --- a/src/olmo_core/nn/attention.py +++ b/src/olmo_core/nn/attention.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed import DeviceMesh -from torch.distributed.tensor import Placement, Shard +from torch.distributed.tensor import Placement, Replicate, Shard from torch.distributed.tensor.parallel import parallelize_module from ..config import Config, DType, StrEnum @@ -354,11 +354,23 @@ def forward( def apply_tp( self, tp_mesh: DeviceMesh, - output_layouts: Optional[Placement] = None, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): - rowwise_parallel, colwise_parallel, _ = get_tp_wrappers(float8_enabled=float8_enabled) + rowwise_parallel, colwise_parallel, prepare_module_input = get_tp_wrappers( + float8_enabled=float8_enabled + ) + + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=prepare_module_input( + input_layouts=None if input_layout is None else (input_layout,), + desired_input_layouts=(Replicate(),), + ), + ) plan = { "w_q": colwise_parallel( @@ -371,7 +383,7 @@ def apply_tp( ), "w_v": colwise_parallel(), "w_out": rowwise_parallel( - output_layouts=output_layouts, use_local_output=use_local_output + output_layouts=output_layout, use_local_output=use_local_output ), } if self.q_norm is not None: @@ -486,11 +498,12 @@ def forward( def apply_tp( self, tp_mesh: DeviceMesh, - output_layouts: Optional[Placement] = None, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): - del tp_mesh, output_layouts, use_local_output, float8_enabled + del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled raise NotImplementedError("TP is not implemented yet for the normalized attention variant") @@ -622,11 +635,12 @@ def forward( def apply_tp( self, tp_mesh: DeviceMesh, - output_layouts: Optional[Placement] = None, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): - del tp_mesh, output_layouts, use_local_output, float8_enabled + del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled raise NotImplementedError("TP is not implemented yet for the fused attention variant") diff --git 
a/src/olmo_core/nn/feed_forward.py b/src/olmo_core/nn/feed_forward.py index 887d2c82e..8dafe2376 100644 --- a/src/olmo_core/nn/feed_forward.py +++ b/src/olmo_core/nn/feed_forward.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch.distributed import DeviceMesh from torch.distributed.tensor.parallel import parallelize_module -from torch.distributed.tensor.placement_types import Placement +from torch.distributed.tensor.placement_types import Placement, Replicate from ..config import Config, DType, StrEnum from ..doc_utils import beta_feature @@ -124,11 +124,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def apply_tp( self, tp_mesh: DeviceMesh, - output_layouts: Optional[Placement] = None, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): - rowwise_parallel, colwise_parallel, _ = get_tp_wrappers(float8_enabled=float8_enabled) + rowwise_parallel, colwise_parallel, prepare_module_input = get_tp_wrappers( + float8_enabled=float8_enabled + ) + + parallelize_module( + module=self, + device_mesh=tp_mesh, + parallelize_plan=prepare_module_input( + input_layouts=None if input_layout is None else (input_layout,), + desired_input_layouts=(Replicate(),), + ), + ) parallelize_module( module=self, @@ -136,7 +148,7 @@ def apply_tp( parallelize_plan={ "w1": colwise_parallel(), "w2": rowwise_parallel( - output_layouts=output_layouts, use_local_output=use_local_output + output_layouts=output_layout, use_local_output=use_local_output ), "w3": colwise_parallel(), }, @@ -188,11 +200,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def apply_tp( self, tp_mesh: DeviceMesh, - output_layouts: Optional[Placement] = None, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): - del tp_mesh, output_layouts, use_local_output, float8_enabled + del tp_mesh, input_layout, output_layout, use_local_output, float8_enabled raise NotImplementedError( "TP is not implemented yet for the normalized feed-forward variant" diff --git a/src/olmo_core/nn/lm_head.py b/src/olmo_core/nn/lm_head.py index 2da8e01e5..52cf1f249 100644 --- a/src/olmo_core/nn/lm_head.py +++ b/src/olmo_core/nn/lm_head.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Optional import torch import torch.nn as nn @@ -8,7 +8,7 @@ from torch.distributed.tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, - ParallelStyle, + PrepareModuleInput, SequenceParallel, parallelize_module, ) @@ -132,24 +132,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.norm(x) if self.norm is not None else x return self.w_out(h) - @property - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - return Shard(1) if self.norm is not None else Replicate() - - def apply_tp(self, tp_mesh: DeviceMesh, loss_parallel: bool = False): - tp_plan: Dict[str, ParallelStyle] = { - "w_out": ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1) if loss_parallel else Replicate(), - use_local_output=not loss_parallel, + def apply_tp( + self, + tp_mesh: DeviceMesh, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, + ): + parallelize_module( + module=self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( + 
input_layouts=None if input_layout is None else (input_layout,), + desired_input_layouts=(Shard(1) if self.norm is not None else Replicate(),), ), - } + ) + if self.norm is not None: - tp_plan["norm"] = SequenceParallel() + parallelize_module( + module=self, + device_mesh=tp_mesh, + parallelize_plan=SequenceParallel(), + ) + parallelize_module( - module=self, + module=self.w_out, device_mesh=tp_mesh, - parallelize_plan=tp_plan, + parallelize_plan=ColwiseParallel( + output_layouts=output_layout, + use_local_output=use_local_output, + ), ) @@ -192,8 +204,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: sz = self.sz * (self.sz_init_value / self.sz_init_scaling) return sz * self.w_out(x) - def apply_tp(self, tp_mesh: DeviceMesh, loss_parallel: bool = False): - del tp_mesh, loss_parallel + def apply_tp( + self, + tp_mesh: DeviceMesh, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, + ): + del tp_mesh, input_layout, output_layout, use_local_output raise NotImplementedError("TP is not implemented yet for the normalized LM head variant") diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 84d93f8fa..066de8772 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -7,7 +7,11 @@ import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.tensor import Placement, Replicate, Shard -from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module +from torch.distributed.tensor.parallel import ( + PrepareModuleInput, + PrepareModuleOutput, + parallelize_module, +) from olmo_core.config import Config, DType, StrEnum from olmo_core.exceptions import OLMoConfigurationError @@ -218,11 +222,19 @@ def prepare_experts_for_fsdp(self, **kwargs): def apply_tp( self, tp_mesh: DeviceMesh, - output_layouts: Optional[Placement] = None, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, use_local_output: bool = True, float8_enabled: bool = False, ): - # Input layouts assumed to be (Shard(1),) + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( + input_layouts=None if input_layout is None else (input_layout,), + desired_input_layouts=(Shard(1),), + ), + ) # Sequence parallel self.router.apply_tp(tp_mesh, float8_enabled=float8_enabled) @@ -239,7 +251,7 @@ def apply_tp( device_mesh=tp_mesh, parallelize_plan=PrepareModuleOutput( output_layouts=(Shard(1),), - desired_output_layouts=(output_layouts or Replicate(),), + desired_output_layouts=(output_layout or Replicate(),), use_local_output=use_local_output, ), ) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 39b897346..5da24ea6d 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -1,13 +1,13 @@ import math from abc import abstractmethod from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch import torch.nn as nn from torch.distributed import DeviceMesh -from torch.distributed.tensor import Placement, Replicate, Shard -from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module +from torch.distributed.tensor import Shard +from torch.distributed.tensor.parallel import parallelize_module from olmo_core.config import Config, StrEnum from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel @@ -20,7 +20,6 @@ 
from ..functional import l2_normalize from ..layer_norm import LayerNormConfig from ..moe import MoEConfig -from ..utils import get_tp_wrappers class TransformerBlockType(StrEnum): @@ -150,11 +149,6 @@ def forward( def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): raise NotImplementedError - @property - @abstractmethod - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - raise NotImplementedError - class TransformerBlock(TransformerBlockBase): """ @@ -201,40 +195,32 @@ def forward( ) return h + self.dropout(self.feed_forward(self.feed_forward_norm(h))) - @property - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - return Shard(1) - def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): - _, _, prepare_module_input = get_tp_wrappers(float8_enabled=float8_enabled) - - plan = { - "attention_norm": SequenceParallel(), - "attention": prepare_module_input( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "feed_forward_norm": SequenceParallel(), - "feed_forward": prepare_module_input( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - } - if isinstance(self.dropout, nn.Dropout): - plan["dropout"] = SequenceParallel() parallelize_module( - module=self, - device_mesh=tp_mesh, - parallelize_plan=plan, + self.attention_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel() ) self.attention.apply_tp( - tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled + tp_mesh, + input_layout=Shard(1), + output_layout=Shard(1), + use_local_output=False, + float8_enabled=float8_enabled, ) + + parallelize_module( + self.feed_forward_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel() + ) + self.feed_forward.apply_tp( - tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled + tp_mesh, + output_layout=Shard(1), + use_local_output=False, + float8_enabled=float8_enabled, ) + parallelize_module(self.dropout, device_mesh=tp_mesh, parallelize_plan=SequenceParallel()) + class ReorderedNormTransformerBlock(TransformerBlock): """ @@ -329,9 +315,6 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): "TP is not implemented yet for the normalized transformer block variant" ) - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - raise NotImplementedError - @torch.no_grad() def normalize_matrices(self): """ @@ -404,41 +387,32 @@ def forward( def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): self.feed_forward_moe.apply_ep(ep_mesh, **kwargs) - @property - def tp_input_layouts(self) -> Union[Placement, Tuple[Placement, ...]]: - return Shard(1) - def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): - _, _, prepare_module_input = get_tp_wrappers(float8_enabled=float8_enabled) - - plan = { - "attention_norm": SequenceParallel(), - "attention": prepare_module_input( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "feed_forward_norm": SequenceParallel(), - "feed_forward_moe": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Shard(1),), - use_local_output=True, - ), - } - if isinstance(self.dropout, nn.Dropout): - plan["dropout"] = SequenceParallel() parallelize_module( - module=self, - device_mesh=tp_mesh, - parallelize_plan=plan, + self.attention_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel() ) self.attention.apply_tp( - tp_mesh, output_layouts=Shard(1), use_local_output=False, 
float8_enabled=float8_enabled + tp_mesh, + input_layout=Shard(1), + output_layout=Shard(1), + use_local_output=False, + float8_enabled=float8_enabled, + ) + + parallelize_module( + self.feed_forward_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel() ) + self.feed_forward_moe.apply_tp( - tp_mesh, output_layouts=Shard(1), use_local_output=False, float8_enabled=float8_enabled + tp_mesh, + output_layout=Shard(1), + use_local_output=False, + float8_enabled=float8_enabled, ) + parallelize_module(self.dropout, device_mesh=tp_mesh, parallelize_plan=SequenceParallel()) + class MoEReorderedNormTransformerBlock(MoETransformerBlock): """ diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index b46217472..063c6fc94 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn from torch.distributed import DeviceMesh +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module from olmo_core.config import StrEnum from olmo_core.data.utils import get_cumulative_document_lengths @@ -284,13 +286,6 @@ def apply_tp( :param loss_parallel: Set to ``True`` if parallelizing the loss function as well. :param float8_enabled: Set this to ``True`` if training with float8 linear layers. """ - from torch.distributed.tensor import Replicate - from torch.distributed.tensor.parallel import ( - PrepareModuleInput, - RowwiseParallel, - parallelize_module, - ) - if self.embeddings is not None: parallelize_module( self.embeddings, @@ -301,18 +296,6 @@ def apply_tp( ), ) - if self.lm_head is not None: - parallelize_module( - self.lm_head, - device_mesh=tp_mesh, - parallelize_plan=PrepareModuleInput( - # block output layouts are same as block input layouts - input_layouts=cast(TransformerBlockBase, self.blocks["0"]).tp_input_layouts, - desired_input_layouts=self.lm_head.tp_input_layouts, - ), - ) - self.lm_head.apply_tp(tp_mesh, loss_parallel=loss_parallel) - # Apply tensor + sequence parallelism to every transformer block. # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. @@ -320,10 +303,12 @@ def apply_tp( for block in self.blocks.values(): block = cast(TransformerBlockBase, block) block.apply_tp(tp_mesh, float8_enabled=float8_enabled) - parallelize_module( - block, - device_mesh=tp_mesh, - parallelize_plan=PrepareModuleInput(desired_input_layouts=block.tp_input_layouts), + + if self.lm_head is not None: + self.lm_head.apply_tp( + tp_mesh, + output_layout=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, ) def apply_activation_checkpointing( From ae24c2af5711db3f2a082b41189eeef46b3db09b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:20:28 -0800 Subject: [PATCH 172/230] fix --- src/olmo_core/nn/lm_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/lm_head.py b/src/olmo_core/nn/lm_head.py index 52cf1f249..010d970fc 100644 --- a/src/olmo_core/nn/lm_head.py +++ b/src/olmo_core/nn/lm_head.py @@ -150,7 +150,7 @@ def apply_tp( if self.norm is not None: parallelize_module( - module=self, + module=self.norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel(), ) From da04166544977f949fb4f4a0e788abee0680e72a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:23:01 -0800 Subject: [PATCH 173/230] try this? 
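This passes the sequence-sharded input layout through to the MoE feed-forward so its PrepareModuleInput plan knows the activations already arrive as Shard(1). The sketch below is not part of this patch; the toy module and helper names are illustrative assumptions. It only shows the PrepareModuleInput mechanics this change relies on: declared input layouts, desired layouts, and unwrapping to a local tensor.

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module


class ToySequenceModule(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Works on whatever local slice of the sequence this rank holds.
        return x + 1


def parallelize_toy_module(module: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    # Callers hand in a DTensor sharded on the sequence dim (Shard(1)); the
    # module wants the same layout, so nothing is redistributed, but
    # use_local_output=True unwraps the DTensor to a plain local tensor.
    return parallelize_module(
        module,
        device_mesh=tp_mesh,
        parallelize_plan=PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Shard(1),),
            use_local_output=True,
        ),
    )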
--- src/olmo_core/nn/transformer/block.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 5da24ea6d..0fd71c88e 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -406,6 +406,7 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): self.feed_forward_moe.apply_tp( tp_mesh, + input_layout=Shard(1), output_layout=Shard(1), use_local_output=False, float8_enabled=float8_enabled, From 728121e98c2cc4e62109be70177fc186c66c3c94 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:26:10 -0800 Subject: [PATCH 174/230] fix --- src/olmo_core/nn/transformer/block.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 0fd71c88e..3b3c4e685 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -10,7 +10,10 @@ from torch.distributed.tensor.parallel import parallelize_module from olmo_core.config import Config, StrEnum -from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel +from olmo_core.distributed.parallel.tensor_parallel import ( + PrepareModuleInput, + SequenceParallel, +) from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError @@ -196,6 +199,14 @@ def forward( return h + self.dropout(self.feed_forward(self.feed_forward_norm(h))) def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( + desired_input_layouts=(Shard(1),), + ), + ) + parallelize_module( self.attention_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel() ) @@ -388,6 +399,14 @@ def apply_ep(self, ep_mesh: DeviceMesh, **kwargs): self.feed_forward_moe.apply_ep(ep_mesh, **kwargs) def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: bool = False): + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( + desired_input_layouts=(Shard(1),), + ), + ) + parallelize_module( self.attention_norm, device_mesh=tp_mesh, parallelize_plan=SequenceParallel() ) From d613fea9953b322ec8e53a4a43dbd25bc104ce96 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:26:43 -0800 Subject: [PATCH 175/230] ooops --- src/olmo_core/nn/transformer/block.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 3b3c4e685..503dd10dc 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -7,13 +7,10 @@ import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.tensor import Shard -from torch.distributed.tensor.parallel import parallelize_module +from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module from olmo_core.config import Config, StrEnum -from olmo_core.distributed.parallel.tensor_parallel import ( - PrepareModuleInput, - SequenceParallel, -) +from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError From dc4a8a8caf625238bc07b008088b86d8cb4b06a5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:30:09 -0800 Subject: [PATCH 176/230] debug --- src/olmo_core/nn/moe/moe.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index 066de8772..c00ee730d 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -189,6 +189,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. """ + log.info(f"{x=}") expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) From e19ace58337d05bb417036f8f0aad4941877b179 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:31:53 -0800 Subject: [PATCH 177/230] fix --- src/olmo_core/nn/moe/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/moe.py b/src/olmo_core/nn/moe/moe.py index c00ee730d..bc2a07a5c 100644 --- a/src/olmo_core/nn/moe/moe.py +++ b/src/olmo_core/nn/moe/moe.py @@ -189,7 +189,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :returns: The output of the MoE layer, the optional load-balancing loss, and the optional router Z-loss. """ - log.info(f"{x=}") expert_logits, expert_scores, expert_weights, expert_indices = self.router(x) out, batch_size_per_expert = self.experts(x, expert_weights, expert_indices) @@ -234,6 +233,7 @@ def apply_tp( parallelize_plan=PrepareModuleInput( input_layouts=None if input_layout is None else (input_layout,), desired_input_layouts=(Shard(1),), + use_local_output=True, ), ) From 94d31780e1f11ec7337a7187cacf9ac9d676ea88 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:33:51 -0800 Subject: [PATCH 178/230] extra safety --- src/olmo_core/nn/moe/parallel_mlp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/moe/parallel_mlp.py b/src/olmo_core/nn/moe/parallel_mlp.py index 991a2ca31..cc2e81903 100644 --- a/src/olmo_core/nn/moe/parallel_mlp.py +++ b/src/olmo_core/nn/moe/parallel_mlp.py @@ -9,7 +9,7 @@ import torch.nn as nn from torch.distributed import DeviceMesh -from olmo_core.distributed.utils import get_world_size +from olmo_core.distributed.utils import get_local_tensor, get_world_size from olmo_core.utils import get_default_device, move_to_device from ..buffer_cache import BufferCache @@ -128,6 +128,7 @@ def forward( :returns: The output with the same shape as ``x`` and a tensor with shape ``(num_local_experts,)`` containing the number of items/tokens routed to each (local) expert. """ + x = get_local_tensor(x) in_shape = x.size() # Compute the experts. 
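The get_local_tensor(x) guard added above ensures the parallel MLP's expert math always runs on the shard this rank owns, even if an upstream tensor-parallel plan hands it a DTensor. A minimal stand-in for that kind of helper, assuming only torch's DTensor API (the function name below is illustrative, not the repo's get_local_tensor):

import torch
from torch.distributed.tensor import DTensor


def to_local_if_dtensor(x: torch.Tensor) -> torch.Tensor:
    # Unwrap a DTensor to the local shard owned by this rank; plain tensors
    # pass through unchanged, so downstream expert computations never see a
    # distributed wrapper.
    return x.to_local() if isinstance(x, DTensor) else x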
From f3de9c62fcf484230f79251c35088572a13796b5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 16:40:30 -0800 Subject: [PATCH 179/230] fix router test --- src/test/nn/moe/router_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/nn/moe/router_test.py b/src/test/nn/moe/router_test.py index 81f271225..6c41b6089 100644 --- a/src/test/nn/moe/router_test.py +++ b/src/test/nn/moe/router_test.py @@ -21,7 +21,8 @@ def test_router(device: torch.device, uniform_expert_assignment: bool): uniform_expert_assignment=uniform_expert_assignment, ).to(device) x = torch.randn((2, 4, 128), device=device) - logits, weights, indices = router(x) + logits, scores, weights, indices = router(x) assert logits.shape == (8, 4) + assert scores.shape == (8, 4) assert weights.shape == (8, 2) assert indices.shape == (8, 2) From 9c9f5e787e78089c59ac28b1d20b275ccb9cdab5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 21:57:32 -0800 Subject: [PATCH 180/230] implement loss parallel --- src/olmo_core/nn/cross_entropy_loss.py | 104 ++++++++++++++++++ .../nn/functional/cross_entropy_loss.py | 30 +++++ src/olmo_core/nn/transformer/model.py | 2 +- .../train/train_module/transformer.py | 52 ++++----- 4 files changed, 157 insertions(+), 31 deletions(-) create mode 100644 src/olmo_core/nn/cross_entropy_loss.py diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py new file mode 100644 index 000000000..4b43616fa --- /dev/null +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -0,0 +1,104 @@ +import logging +from typing import Literal, Optional, Tuple + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.tensor import Placement, Shard +from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module + +from .functional import cross_entropy_loss, fused_cross_entropy_loss + +log = logging.getLogger(__name__) + + +class CrossEntropyLoss(nn.Module): + def __init__( + self, + *, + ignore_index: int = -100, + reduction: Literal["mean", "sum", "none"] = "mean", + z_loss_multiplier: Optional[float] = None, + compile: bool = False, + fused: bool = False, + ): + super().__init__() + self.ignore_index = ignore_index + self.reduction: Literal["mean", "sum", "none"] = reduction + self.z_loss_multiplier = z_loss_multiplier + self.base_loss_fn = fused_cross_entropy_loss if fused else cross_entropy_loss + if compile: + if torch.cuda.is_available(): + log.info("Compiling loss function...") + self.base_loss_fn = torch.compile(self.base_loss_fn) + else: + log.warning("Skipping loss compilation since CUDA is not available") + + def forward( + self, + logits: torch.Tensor, + labels: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # shape: (batch_size, seq_len - 1, vocab_size) + logits_for_loss = logits[..., :-1, :].contiguous() + + # shape: (batch_size * (seq_len - 1), vocab_size) + logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1)) + + # shape: (batch_size, seq_len - 1) -> (batch_size * (seq_len - 1),) + labels_for_loss = labels.view(-1) + + # shape: depends on reduction + ce_loss, z_loss = self.base_loss_fn( + logits_for_loss, + labels_for_loss, + ignore_index=self.ignore_index, + reduction=self.reduction, + compute_z_loss=self.z_loss_multiplier is not None, + z_loss_multiplier=self.z_loss_multiplier or 1e-4, + ) + + if self.reduction == "none": + ce_loss = ce_loss.view(logits.shape[:-1]) + if z_loss is not None: + z_loss = z_loss.view(logits.shape[:-1]) 
+ + return ce_loss, z_loss + + def apply_tp( + self, + tp_mesh: DeviceMesh, + input_layout: Optional[Placement] = None, + shard_dimension: int = 1, + # output_layout: Optional[Placement] = None, + # use_local_output: bool = True, + ): + if self.reduction == "none": + raise NotImplementedError(self.reduction) + + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( + input_layouts=None if input_layout is None else (input_layout, input_layout), # type: ignore + desired_input_layouts=(Shard(shard_dimension),), + use_local_output=True, + ), + ) + + # output_layout = output_layout or Replicate() + # parallelize_module( + # self, + # device_mesh=tp_mesh, + # parallelize_plan=PrepareModuleOutput( + # output_layouts=( # type: ignore + # Shard(0), + # None if self.z_loss_multiplier is None else Shard(0), + # ), + # desired_output_layouts=( # type: ignore + # output_layout, + # None if self.z_loss_multiplier is None else output_layout, + # ), + # use_local_output=use_local_output, + # ), + # ) diff --git a/src/olmo_core/nn/functional/cross_entropy_loss.py b/src/olmo_core/nn/functional/cross_entropy_loss.py index 205a708b7..9c29a81cd 100644 --- a/src/olmo_core/nn/functional/cross_entropy_loss.py +++ b/src/olmo_core/nn/functional/cross_entropy_loss.py @@ -1,11 +1,41 @@ +import logging from typing import Callable, Literal, Optional, Tuple import torch +import torch.nn as nn import torch.nn.functional as F __all__ = ["cross_entropy_loss", "fused_cross_entropy_loss"] +log = logging.getLogger(__name__) + + +class CrossEntropyLoss(nn.Module): + def __init__( + self, + *, + ignore_index: int = -100, + reduction: Literal["mean", "sum", "none"] = "mean", + compute_z_loss: bool = False, + z_loss_multiplier: float = 1e-4, + compile: bool = False, + fused: bool = False, + ): + super().__init__() + self.ignore_index = ignore_index + self.reduction = reduction + self.compute_z_loss = compute_z_loss + self.z_loss_multiplier = z_loss_multiplier + self.base_loss_fn = fused_cross_entropy_loss if fused else cross_entropy_loss + if compile: + if torch.cuda.is_available(): + log.info("Compiling loss function...") + self.base_loss_fn = torch.compile(self.base_loss_fn) + else: + log.warning("Skipping loss compilation since CUDA is not available") + + def cross_entropy_loss( logits: torch.Tensor, labels: torch.Tensor, diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 063c6fc94..562fb0bb8 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -307,7 +307,7 @@ def apply_tp( if self.lm_head is not None: self.lm_head.apply_tp( tp_mesh, - output_layout=Shard(-1) if loss_parallel else Replicate(), + output_layout=Shard(1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, ) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index dc94a89a9..361399360 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -31,10 +31,7 @@ from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler -from olmo_core.nn.functional.cross_entropy_loss import ( - cross_entropy_loss, - fused_cross_entropy_loss, -) +from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss from olmo_core.nn.transformer import ( MoETransformer, NormalizedTransformer, @@ -265,6 +262,7 @@ def __init__( # Validate 
some options. if fused_loss and compile_loss: raise OLMoConfigurationError("'fused_loss' is not compatible with 'compile_loss'") + if rank_microbatch_size % max_sequence_length != 0: raise OLMoConfigurationError( f"'rank_microbatch_size' ({rank_microbatch_size:,d} tokens) must be divisible by " @@ -277,12 +275,20 @@ def __init__( ) log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") - self.base_loss_fn = fused_cross_entropy_loss if fused_loss else cross_entropy_loss - if compile_loss: - if torch.cuda.is_available(): - self.base_loss_fn = torch.compile(self.base_loss_fn) - else: - log.warning("Skipping loss compilation since CUDA is not available") + self.label_ignore_index = label_ignore_index + self._train_loss_fn = CrossEntropyLoss( + ignore_index=label_ignore_index, + reduction="sum", + z_loss_multiplier=z_loss_multiplier, + compile=compile_loss, + fused=fused_loss, + ) + self._eval_loss_fn = CrossEntropyLoss( + ignore_index=label_ignore_index, + reduction="none", + compile=compile_loss, + fused=fused_loss, + ) self.float8_handler: Optional[Float8Handler] = None float8_enabled = False @@ -308,8 +314,11 @@ def __init__( self.model.apply_tp( tp_mesh, float8_enabled=float8_enabled, - loss_parallel=False, # TODO (epwalsh): figure out if this will work w/ z-loss + loss_parallel=True, ) + self._train_loss_fn.apply_tp(tp_mesh) + # TODO: parallel eval loss? The tricky part is we don't reduce it. + tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" @@ -373,7 +382,6 @@ def __init__( self.rank_microbatch_size = rank_microbatch_size self.max_sequence_length = max_sequence_length - self.z_loss_multiplier = z_loss_multiplier self.autocast_precision = autocast_precision self.max_grad_norm = max_grad_norm self.scheduler = scheduler @@ -384,7 +392,6 @@ def __init__( flatten_optimizer_state_dict=True, strict=True ) self.load_key_mapping = load_key_mapping - self.label_ignore_index = label_ignore_index @property def dp_process_group(self) -> Optional[dist.ProcessGroup]: @@ -399,19 +406,10 @@ def eval_batch_spec(self) -> EvalBatchSpec: def loss_fn( self, logits: torch.Tensor, labels: torch.Tensor, batch_num_tokens_for_loss: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) - # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' # (the total number of tokens used in the loss across the whole batch, not just the micro batch) # to avoid biasing the loss in the case where micro-batches might not be the same size. 
- ce_loss, z_loss = self.base_loss_fn( - logits_for_loss, - labels_for_loss, - ignore_index=self.label_ignore_index, - reduction="sum", - compute_z_loss=self.z_loss_multiplier is not None, - z_loss_multiplier=self.z_loss_multiplier or 1e-4, - ) + ce_loss, z_loss = self._train_loss_fn(logits, labels) ce_loss.div_(batch_num_tokens_for_loss) if z_loss is not None: @@ -429,13 +427,7 @@ def loss_fn( ) def eval_loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) - ce_loss, _ = self.base_loss_fn( - logits_for_loss, - labels_for_loss, - ignore_index=self.label_ignore_index, - reduction="none", - ) + ce_loss, _ = self._eval_loss_fn(logits, labels) return ce_loss.view(logits.shape[0], -1) def on_attach(self): From 7d7249e8540f90a11be8746a0fb92659d1bae9a7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 7 Feb 2025 22:04:30 -0800 Subject: [PATCH 181/230] do in pipeline too --- .../nn/functional/cross_entropy_loss.py | 30 ---------- .../train/train_module/transformer.py | 7 +-- .../train_module/transformer_pipeline.py | 55 ++++++++----------- 3 files changed, 26 insertions(+), 66 deletions(-) diff --git a/src/olmo_core/nn/functional/cross_entropy_loss.py b/src/olmo_core/nn/functional/cross_entropy_loss.py index 9c29a81cd..205a708b7 100644 --- a/src/olmo_core/nn/functional/cross_entropy_loss.py +++ b/src/olmo_core/nn/functional/cross_entropy_loss.py @@ -1,41 +1,11 @@ -import logging from typing import Callable, Literal, Optional, Tuple import torch -import torch.nn as nn import torch.nn.functional as F __all__ = ["cross_entropy_loss", "fused_cross_entropy_loss"] -log = logging.getLogger(__name__) - - -class CrossEntropyLoss(nn.Module): - def __init__( - self, - *, - ignore_index: int = -100, - reduction: Literal["mean", "sum", "none"] = "mean", - compute_z_loss: bool = False, - z_loss_multiplier: float = 1e-4, - compile: bool = False, - fused: bool = False, - ): - super().__init__() - self.ignore_index = ignore_index - self.reduction = reduction - self.compute_z_loss = compute_z_loss - self.z_loss_multiplier = z_loss_multiplier - self.base_loss_fn = fused_cross_entropy_loss if fused else cross_entropy_loss - if compile: - if torch.cuda.is_available(): - log.info("Compiling loss function...") - self.base_loss_fn = torch.compile(self.base_loss_fn) - else: - log.warning("Skipping loss compilation since CUDA is not available") - - def cross_entropy_loss( logits: torch.Tensor, labels: torch.Tensor, diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 361399360..c86e8f218 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -43,7 +43,7 @@ from olmo_core.optim.scheduler import Scheduler from olmo_core.utils import gc_cuda, get_default_device, mark_dynamic, move_to_device -from ..common import ReduceType, reshape_inputs_for_loss +from ..common import ReduceType from .train_module import EvalBatchSpec, TrainModule log = logging.getLogger(__name__) @@ -524,7 +524,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Batch losses to record. 
ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) z_batch_loss: Optional[torch.Tensor] = None - if self.z_loss_multiplier is not None: + if self._train_loss_fn.z_loss_multiplier is not None: z_batch_loss = move_to_device(torch.tensor(0.0), self.device) auxiliary_batch_losses: Dict[str, torch.Tensor] = {} @@ -579,8 +579,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Record loss metrics. self.record_ce_loss(ce_batch_loss, ReduceType.mean) - if self.z_loss_multiplier is not None: - assert z_batch_loss is not None + if z_batch_loss is not None: self.record_metric( "Z loss", z_batch_loss, diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 02773ae7b..a95b758ed 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -33,16 +33,13 @@ from olmo_core.distributed.utils import get_local_tensor, get_world_size from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler -from olmo_core.nn.functional.cross_entropy_loss import ( - cross_entropy_loss, - fused_cross_entropy_loss, -) +from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss from olmo_core.nn.transformer import NormalizedTransformer, Transformer from olmo_core.optim import OptimConfig, SkipStepOptimizer from olmo_core.optim.scheduler import Scheduler from olmo_core.utils import gc_cuda, get_default_device, mark_dynamic, move_to_device -from ..common import ReduceType, reshape_inputs_for_loss +from ..common import ReduceType from .train_module import EvalBatchSizeUnit, EvalBatchSpec, TrainModule from .transformer import ( TransformerActivationCheckpointingConfig, @@ -311,6 +308,7 @@ def __init__( # Validate some options. if fused_loss and compile_loss: raise OLMoConfigurationError("'fused_loss' is not compatible with 'compile_loss'") + if rank_microbatch_size % max_sequence_length != 0: raise OLMoConfigurationError( f"'rank_microbatch_size' ({rank_microbatch_size:,d} tokens) must be divisible by " @@ -323,12 +321,20 @@ def __init__( ) log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") - self.base_loss_fn = fused_cross_entropy_loss if fused_loss else cross_entropy_loss - if compile_loss: - if torch.cuda.is_available(): - self.base_loss_fn = torch.compile(self.base_loss_fn) - else: - log.warning("Skipping loss compilation since CUDA is not available") + self.label_ignore_index = label_ignore_index + self._train_loss_fn = CrossEntropyLoss( + ignore_index=label_ignore_index, + reduction="sum", + z_loss_multiplier=z_loss_multiplier, + compile=compile_loss, + fused=fused_loss, + ) + self._eval_loss_fn = CrossEntropyLoss( + ignore_index=label_ignore_index, + reduction="none", + compile=compile_loss, + fused=fused_loss, + ) self.float8_handler: Optional[Float8Handler] = None float8_enabled = False @@ -365,8 +371,10 @@ def __init__( model.apply_tp( tp_mesh, float8_enabled=float8_enabled, - loss_parallel=False, # TODO (epwalsh): figure out if this will work w/ z-loss + loss_parallel=True, ) + self._train_loss_fn.apply_tp(tp_mesh) + # TODO: parallel eval loss? The tricky part is we don't reduce it. 
tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" @@ -431,7 +439,6 @@ def __init__( self.rank_microbatch_size = rank_microbatch_size self.max_sequence_length = max_sequence_length - self.z_loss_multiplier = z_loss_multiplier self.autocast_precision = autocast_precision self.max_grad_norm = max_grad_norm self.scheduler = scheduler @@ -442,7 +449,6 @@ def __init__( flatten_optimizer_state_dict=True, strict=False ) self.load_key_mapping = load_key_mapping - self.label_ignore_index = label_ignore_index for model in self.model_parts: if model.is_moe: @@ -488,19 +494,10 @@ def eval_pp_schedule(self) -> PipelineSchedule: def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: assert self._batch_num_tokens_for_loss is not None - logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) - # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' # (the total number of tokens used in the loss across the whole batch, not just the micro batch) # to avoid biasing the loss in the case where micro-batches might not be the same size. - ce_loss, z_loss = self.base_loss_fn( - logits_for_loss, - labels_for_loss, - ignore_index=self.label_ignore_index, - reduction="sum", - compute_z_loss=self.z_loss_multiplier is not None, - z_loss_multiplier=self.z_loss_multiplier or 1e-4, - ) + ce_loss, z_loss = self._train_loss_fn(logits, labels) ce_loss.div_(self._batch_num_tokens_for_loss) if z_loss is not None: @@ -525,13 +522,7 @@ def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: return loss def eval_loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - logits_for_loss, labels_for_loss = reshape_inputs_for_loss(logits, labels) - ce_loss, _ = self.base_loss_fn( - logits_for_loss, - labels_for_loss, - ignore_index=self.label_ignore_index, - reduction="none", - ) + ce_loss, _ = self._eval_loss_fn(logits, labels) return ce_loss.view(logits.shape[0], -1) def on_attach(self): @@ -673,7 +664,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): self.record_ce_loss( self._ce_batch_loss / get_world_size(self.dp_process_group), ReduceType.sum ) - if self.z_loss_multiplier is not None: + if self._train_loss_fn.z_loss_multiplier is not None: if self._z_batch_loss is None: self.record_metric("Z loss", 0.0, ReduceType.sum, namespace="train") else: From 3994f4114ecfded4f5765ea37d106f7d0f836b58 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 12:20:54 -0800 Subject: [PATCH 182/230] start test for CE loss --- src/olmo_core/nn/cross_entropy_loss.py | 56 +++++++++++++++++--------- src/test/nn/cross_entropy_loss_test.py | 53 ++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 20 deletions(-) create mode 100644 src/test/nn/cross_entropy_loss_test.py diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 4b43616fa..7076a1c6e 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -4,8 +4,12 @@ import torch import torch.nn as nn from torch.distributed import DeviceMesh -from torch.distributed.tensor import Placement, Shard -from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module +from torch.distributed.tensor import Placement, Replicate, Shard +from torch.distributed.tensor.parallel import ( + PrepareModuleInput, + PrepareModuleOutput, + parallelize_module, +) from .functional import 
cross_entropy_loss, fused_cross_entropy_loss @@ -27,6 +31,7 @@ def __init__( self.reduction: Literal["mean", "sum", "none"] = reduction self.z_loss_multiplier = z_loss_multiplier self.base_loss_fn = fused_cross_entropy_loss if fused else cross_entropy_loss + self._tp_enabled: bool = False if compile: if torch.cuda.is_available(): log.info("Compiling loss function...") @@ -34,6 +39,10 @@ def __init__( else: log.warning("Skipping loss compilation since CUDA is not available") + @property + def tp_enabled(self) -> bool: + return self._tp_enabled + def forward( self, logits: torch.Tensor, @@ -62,6 +71,10 @@ def forward( ce_loss = ce_loss.view(logits.shape[:-1]) if z_loss is not None: z_loss = z_loss.view(logits.shape[:-1]) + elif self.tp_enabled: + ce_loss = ce_loss.unsqueeze(0) + if z_loss is not None: + z_loss = z_loss.unsqueeze(0) return ce_loss, z_loss @@ -70,8 +83,8 @@ def apply_tp( tp_mesh: DeviceMesh, input_layout: Optional[Placement] = None, shard_dimension: int = 1, - # output_layout: Optional[Placement] = None, - # use_local_output: bool = True, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, ): if self.reduction == "none": raise NotImplementedError(self.reduction) @@ -86,19 +99,22 @@ def apply_tp( ), ) - # output_layout = output_layout or Replicate() - # parallelize_module( - # self, - # device_mesh=tp_mesh, - # parallelize_plan=PrepareModuleOutput( - # output_layouts=( # type: ignore - # Shard(0), - # None if self.z_loss_multiplier is None else Shard(0), - # ), - # desired_output_layouts=( # type: ignore - # output_layout, - # None if self.z_loss_multiplier is None else output_layout, - # ), - # use_local_output=use_local_output, - # ), - # ) + expected_output_layout = Shard(shard_dimension) if self.reduction == "none" else Shard(0) + desired_output_layout = output_layout or Replicate() + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleOutput( + output_layouts=( # type: ignore + expected_output_layout, + None if self.z_loss_multiplier is None else expected_output_layout, + ), + desired_output_layouts=( # type: ignore + desired_output_layout, + None if self.z_loss_multiplier is None else desired_output_layout, + ), + use_local_output=use_local_output, + ), + ) + + self._tp_enabled = True diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py new file mode 100644 index 000000000..c500883da --- /dev/null +++ b/src/test/nn/cross_entropy_loss_test.py @@ -0,0 +1,53 @@ +from typing import Literal, Optional + +import pytest +import torch + +from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss + +from ..distributed.utils import requires_multi_gpu + + +@pytest.mark.parametrize( + "fused, compile, reduction", + [ + (pytest.param(False, id="default"), pytest.param(False, id="no-compile"), "sum"), + ], +) +@requires_multi_gpu +def test_cross_entropy_loss_parallel( + fused: bool, + compile: bool, + reduction: Literal["sum", "mean", "none"], + z_loss_multiplier: Optional[float] = None, +): + loss_fn = CrossEntropyLoss( + reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier + ) + + B, S, D = 4, 16, 64 + input_ids = torch.randint(0, 256, (B, S), device="cuda") + logits = torch.randn(B, S, D, device="cuda", requires_grad=True) + labels = input_ids.clone()[..., 1:].contiguous() + labels[0][2] = -100 + labels[2][9] = -100 + labels[3][12] = -100 + + batch_num_tokens_for_loss = (labels != -100).sum() + ce_loss, z_loss = loss_fn(logits, labels) + 
ce_loss.div_(batch_num_tokens_for_loss) + if z_loss is not None: + z_loss.div_(batch_num_tokens_for_loss) + + loss = ce_loss + if z_loss is not None: + loss += z_loss + + if reduction != "none": + assert loss.shape == tuple() + else: + assert loss.shape == (B, S - 1) + + # Trigger backward pass. + loss.backward() + assert logits.grad is not None From 5deafea0e54a3736deaa5cad8ed135418f15dca8 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 12:22:00 -0800 Subject: [PATCH 183/230] clean up --- src/test/nn/cross_entropy_loss_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index c500883da..6c1b5e2b7 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "fused, compile, reduction", [ - (pytest.param(False, id="default"), pytest.param(False, id="no-compile"), "sum"), + pytest.param(False, False, "sum", id="default-sum"), ], ) @requires_multi_gpu From 955028611fa7e24941732c175095242a354e2e64 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 12:23:32 -0800 Subject: [PATCH 184/230] fix --- src/test/nn/cross_entropy_loss_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 6c1b5e2b7..a9ef0d67b 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -25,9 +25,9 @@ def test_cross_entropy_loss_parallel( reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier ) - B, S, D = 4, 16, 64 - input_ids = torch.randint(0, 256, (B, S), device="cuda") - logits = torch.randn(B, S, D, device="cuda", requires_grad=True) + B, S, V = 4, 16, 256 + input_ids = torch.randint(0, V, (B, S), device="cuda") + logits = torch.randn(B, S, V, device="cuda", requires_grad=True) labels = input_ids.clone()[..., 1:].contiguous() labels[0][2] = -100 labels[2][9] = -100 From 2c7219ffe821d3b3bd2dd15f76604fdc82742c6c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 12:42:11 -0800 Subject: [PATCH 185/230] clean up --- src/olmo_core/nn/cross_entropy_loss.py | 18 ++--- .../train/train_module/transformer.py | 6 +- src/test/nn/cross_entropy_loss_test.py | 72 +++++++++++++++++-- 3 files changed, 81 insertions(+), 15 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 7076a1c6e..650d92f5f 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -48,16 +48,16 @@ def forward( logits: torch.Tensor, labels: torch.Tensor, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # shape: (batch_size, seq_len - 1, vocab_size) - logits_for_loss = logits[..., :-1, :].contiguous() - - # shape: (batch_size * (seq_len - 1), vocab_size) - logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1)) - - # shape: (batch_size, seq_len - 1) -> (batch_size * (seq_len - 1),) + """ + Compute the CE loss and optionally Z-loss. + + :param logits: The logits of shape ``(*, num_classes)``. + :param labels: The target labels of shape ``(*, )``. + """ + # Flatten inputs for loss function. 
+ logits_for_loss = logits.view(-1, logits.size(-1)) labels_for_loss = labels.view(-1) - # shape: depends on reduction ce_loss, z_loss = self.base_loss_fn( logits_for_loss, labels_for_loss, @@ -84,7 +84,7 @@ def apply_tp( input_layout: Optional[Placement] = None, shard_dimension: int = 1, output_layout: Optional[Placement] = None, - use_local_output: bool = True, + use_local_output: bool = False, ): if self.reduction == "none": raise NotImplementedError(self.reduction) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index c86e8f218..2218b030f 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -544,7 +544,9 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Get loss to optimize for, and the separate detached CE and Z loss values. loss, ce_loss, z_loss = self.loss_fn( - logits, micro_batch["labels"], batch_num_tokens_for_loss + logits[..., :-1, :].contiguous(), + micro_batch["labels"], + batch_num_tokens_for_loss, ) del logits @@ -608,7 +610,7 @@ def eval_batch( logits = self.model_forward(batch) loss: Optional[torch.Tensor] = None if labels is not None: - loss = self.eval_loss_fn(logits, labels) + loss = self.eval_loss_fn(logits[..., :-1, :].contiguous(), labels) return logits, loss diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index a9ef0d67b..9fdf05545 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -2,10 +2,57 @@ import pytest import torch +from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed.tensor import Placement, Replicate, Shard, distribute_tensor +from olmo_core.distributed.utils import get_world_size from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss +from olmo_core.utils import get_default_device -from ..distributed.utils import requires_multi_gpu +from ..distributed.utils import requires_multi_gpu, run_distributed_test + + +def run_cross_entropy_loss_parallel( + fused: bool, + compile: bool, + reduction: Literal["sum", "mean", "none"], + z_loss_multiplier: Optional[float], + logits: torch.Tensor, + labels: torch.Tensor, + batch_num_tokens_for_loss: torch.Tensor, +): + tp_mesh = init_device_mesh("cuda", (get_world_size(),), mesh_dim_names=("tp",)) + + logits = distribute_tensor( + logits.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) + ) + labels = distribute_tensor( + labels.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) + ) + batch_num_tokens_for_loss = batch_num_tokens_for_loss.to(device=get_default_device()) + + loss_fn = CrossEntropyLoss( + reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier + ) + loss_fn.apply_tp(tp_mesh) + + ce_loss, z_loss = loss_fn(logits[..., :-1, :].contiguous(), labels) + ce_loss.div_(batch_num_tokens_for_loss) + if z_loss is not None: + z_loss.div_(batch_num_tokens_for_loss) + + loss = ce_loss + if z_loss is not None: + loss += z_loss + + if reduction != "none": + assert loss.shape == tuple() + else: + assert loss.shape == labels.shape + + # Trigger backward pass. 
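The test above counts only the positions that are not set to the ignore index when it builds the denominator for the summed loss. A tiny self-contained sketch of that bookkeeping (illustrative only, not part of the patch):

import torch

labels = torch.tensor([[5, -100, 7], [-100, 2, 3]])
batch_num_tokens_for_loss = (labels != -100).sum()
print(batch_num_tokens_for_loss)  # tensor(4): ignored positions add nothing to the loss or the denominator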
+ loss.backward() + assert logits.grad is not None @pytest.mark.parametrize( @@ -21,11 +68,12 @@ def test_cross_entropy_loss_parallel( reduction: Literal["sum", "mean", "none"], z_loss_multiplier: Optional[float] = None, ): + B, S, V = 4, 16, 256 + loss_fn = CrossEntropyLoss( reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier ) - B, S, V = 4, 16, 256 input_ids = torch.randint(0, V, (B, S), device="cuda") logits = torch.randn(B, S, V, device="cuda", requires_grad=True) labels = input_ids.clone()[..., 1:].contiguous() @@ -34,7 +82,7 @@ def test_cross_entropy_loss_parallel( labels[3][12] = -100 batch_num_tokens_for_loss = (labels != -100).sum() - ce_loss, z_loss = loss_fn(logits, labels) + ce_loss, z_loss = loss_fn(logits[..., :-1, :].contiguous(), labels) ce_loss.div_(batch_num_tokens_for_loss) if z_loss is not None: z_loss.div_(batch_num_tokens_for_loss) @@ -46,8 +94,24 @@ def test_cross_entropy_loss_parallel( if reduction != "none": assert loss.shape == tuple() else: - assert loss.shape == (B, S - 1) + assert loss.shape == labels.shape # Trigger backward pass. loss.backward() assert logits.grad is not None + + run_distributed_test( + run_cross_entropy_loss_parallel, + world_size=2, + backend="nccl", + start_method="spawn", + func_args=( + fused, + compile, + reduction, + z_loss_multiplier, + logits.detach().cpu(), + labels.detach().cpu(), + batch_num_tokens_for_loss.detach().cpu(), + ), + ) From cfcc2a11c4f0ffb7574a7ab4d37b9e7130f22db4 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 12:47:44 -0800 Subject: [PATCH 186/230] fix? --- src/olmo_core/nn/cross_entropy_loss.py | 4 ++-- src/olmo_core/train/train_module/transformer_pipeline.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 650d92f5f..77c29c61d 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -55,8 +55,8 @@ def forward( :param labels: The target labels of shape ``(*, )``. """ # Flatten inputs for loss function. - logits_for_loss = logits.view(-1, logits.size(-1)) - labels_for_loss = labels.view(-1) + logits_for_loss = logits.reshape(-1, logits.size(-1)) + labels_for_loss = labels.reshape(-1) ce_loss, z_loss = self.base_loss_fn( logits_for_loss, diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index a95b758ed..08da4524a 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -497,7 +497,7 @@ def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' # (the total number of tokens used in the loss across the whole batch, not just the micro batch) # to avoid biasing the loss in the case where micro-batches might not be the same size. 
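The comment above is the heart of the loss accounting in these train modules: every micro-batch contributes a summed loss, and the division by the whole batch's token count happens exactly once. A minimal sketch of why this avoids bias when micro-batches differ in size, using plain torch.nn.functional.cross_entropy as a stand-in for the loss functions in this repo:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 8
logits = torch.randn(10, V)
labels = torch.randint(0, V, (10,))

# Two unequal micro-batches: 7 tokens and 3 tokens.
micro_batches = [(logits[:7], labels[:7]), (logits[7:], labels[7:])]
batch_num_tokens_for_loss = labels.numel()

# Sum per micro-batch, divide once by the global token count.
summed = sum(F.cross_entropy(lg, lb, reduction="sum") for lg, lb in micro_batches)
unbiased = summed / batch_num_tokens_for_loss

# Averaging per-micro-batch means would over-weight the 3-token micro-batch.
naive = sum(F.cross_entropy(lg, lb, reduction="mean") for lg, lb in micro_batches) / 2

reference = F.cross_entropy(logits, labels, reduction="mean")
print(torch.allclose(unbiased, reference))  # True
print(torch.allclose(naive, reference))     # False in general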
- ce_loss, z_loss = self._train_loss_fn(logits, labels) + ce_loss, z_loss = self._train_loss_fn(logits[..., :-1, :].contiguous(), labels) ce_loss.div_(self._batch_num_tokens_for_loss) if z_loss is not None: @@ -522,7 +522,7 @@ def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: return loss def eval_loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - ce_loss, _ = self._eval_loss_fn(logits, labels) + ce_loss, _ = self._eval_loss_fn(logits[..., :-1, :].contiguous(), labels) return ce_loss.view(logits.shape[0], -1) def on_attach(self): From 041bbcd75f1994f05fdc21e00741ed84f227a70c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 13:52:22 -0800 Subject: [PATCH 187/230] clean up --- src/olmo_core/nn/cross_entropy_loss.py | 101 ++++++------------ .../train/train_module/transformer.py | 20 ++-- .../train_module/transformer_pipeline.py | 17 ++- src/test/nn/cross_entropy_loss_test.py | 73 +------------ 4 files changed, 52 insertions(+), 159 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 77c29c61d..b2223e1f2 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -3,20 +3,17 @@ import torch import torch.nn as nn -from torch.distributed import DeviceMesh -from torch.distributed.tensor import Placement, Replicate, Shard -from torch.distributed.tensor.parallel import ( - PrepareModuleInput, - PrepareModuleOutput, - parallelize_module, -) from .functional import cross_entropy_loss, fused_cross_entropy_loss log = logging.getLogger(__name__) -class CrossEntropyLoss(nn.Module): +class LMCrossEntropyLoss(nn.Module): + """ + Cross-entropy loss for language modeling. + """ + def __init__( self, *, @@ -27,22 +24,21 @@ def __init__( fused: bool = False, ): super().__init__() + + if compile and fused: + log.warning(f"{self.__class__.__name__} with fused+compile is experimental") + self.ignore_index = ignore_index self.reduction: Literal["mean", "sum", "none"] = reduction self.z_loss_multiplier = z_loss_multiplier self.base_loss_fn = fused_cross_entropy_loss if fused else cross_entropy_loss - self._tp_enabled: bool = False if compile: if torch.cuda.is_available(): log.info("Compiling loss function...") - self.base_loss_fn = torch.compile(self.base_loss_fn) + self.compile() else: log.warning("Skipping loss compilation since CUDA is not available") - @property - def tp_enabled(self) -> bool: - return self._tp_enabled - def forward( self, logits: torch.Tensor, @@ -51,12 +47,30 @@ def forward( """ Compute the CE loss and optionally Z-loss. - :param logits: The logits of shape ``(*, num_classes)``. - :param labels: The target labels of shape ``(*, )``. + :param logits: The logits of shape ``(B, S, V)``. + The final logits in the sequence dimension will be discarded to match the shape of + the labels. + :param labels: The target labels of shape ``(B, S-1)``. """ - # Flatten inputs for loss function. 
- logits_for_loss = logits.reshape(-1, logits.size(-1)) - labels_for_loss = labels.reshape(-1) + if len(logits.shape) != 3: + raise RuntimeError( + f"expected logits to have shape (B, S, V) but found {tuple(logits.shape)} instead" + ) + + B, S, V = logits.shape + if labels.shape != (B, S - 1): + raise RuntimeError( + f"expected labels to have shape (B, S-1) = {(B, S-1)}, but found {tuple(labels.shape)} instead" + ) + + # shape: (B, S - 1, V) + logits_for_loss = logits[..., :-1, :].contiguous() + + # shape: (B * (S - 1), V) + logits_for_loss = logits_for_loss.view(-1, V) + + # shape: (B, S - 1) -> (B * (S - 1),) + labels_for_loss = labels.view(-1) ce_loss, z_loss = self.base_loss_fn( logits_for_loss, @@ -68,53 +82,8 @@ def forward( ) if self.reduction == "none": - ce_loss = ce_loss.view(logits.shape[:-1]) - if z_loss is not None: - z_loss = z_loss.view(logits.shape[:-1]) - elif self.tp_enabled: - ce_loss = ce_loss.unsqueeze(0) + ce_loss = ce_loss.view(labels.shape) if z_loss is not None: - z_loss = z_loss.unsqueeze(0) + z_loss = z_loss.view(labels.shape) return ce_loss, z_loss - - def apply_tp( - self, - tp_mesh: DeviceMesh, - input_layout: Optional[Placement] = None, - shard_dimension: int = 1, - output_layout: Optional[Placement] = None, - use_local_output: bool = False, - ): - if self.reduction == "none": - raise NotImplementedError(self.reduction) - - parallelize_module( - self, - device_mesh=tp_mesh, - parallelize_plan=PrepareModuleInput( - input_layouts=None if input_layout is None else (input_layout, input_layout), # type: ignore - desired_input_layouts=(Shard(shard_dimension),), - use_local_output=True, - ), - ) - - expected_output_layout = Shard(shard_dimension) if self.reduction == "none" else Shard(0) - desired_output_layout = output_layout or Replicate() - parallelize_module( - self, - device_mesh=tp_mesh, - parallelize_plan=PrepareModuleOutput( - output_layouts=( # type: ignore - expected_output_layout, - None if self.z_loss_multiplier is None else expected_output_layout, - ), - desired_output_layouts=( # type: ignore - desired_output_layout, - None if self.z_loss_multiplier is None else desired_output_layout, - ), - use_local_output=use_local_output, - ), - ) - - self._tp_enabled = True diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 2218b030f..dff5cf051 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -31,7 +31,7 @@ from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler -from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss +from olmo_core.nn.cross_entropy_loss import LMCrossEntropyLoss from olmo_core.nn.transformer import ( MoETransformer, NormalizedTransformer, @@ -260,9 +260,6 @@ def __init__( super().__init__() # Validate some options. 
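Before the train-module edits that follow, a quick sketch of the shape contract the rewritten loss enforces at this point in the series (a sketch only: constructor defaults beyond the arguments shown are assumed, and the class is renamed again in a later patch):

import torch
from olmo_core.nn.cross_entropy_loss import LMCrossEntropyLoss

loss_fn = LMCrossEntropyLoss(ignore_index=-100, reduction="sum", z_loss_multiplier=None)
B, S, V = 2, 8, 32
logits = torch.randn(B, S, V, requires_grad=True)  # full sequence length
labels = torch.randint(0, V, (B, S - 1))           # already shifted by one position
ce_loss, z_loss = loss_fn(logits, labels)          # final logit position dropped internally
assert z_loss is None                              # no z-loss multiplier configured
(ce_loss / (labels != -100).sum()).backward()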
- if fused_loss and compile_loss: - raise OLMoConfigurationError("'fused_loss' is not compatible with 'compile_loss'") - if rank_microbatch_size % max_sequence_length != 0: raise OLMoConfigurationError( f"'rank_microbatch_size' ({rank_microbatch_size:,d} tokens) must be divisible by " @@ -276,14 +273,14 @@ def __init__( log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") self.label_ignore_index = label_ignore_index - self._train_loss_fn = CrossEntropyLoss( + self._train_loss_fn = LMCrossEntropyLoss( ignore_index=label_ignore_index, reduction="sum", z_loss_multiplier=z_loss_multiplier, compile=compile_loss, fused=fused_loss, ) - self._eval_loss_fn = CrossEntropyLoss( + self._eval_loss_fn = LMCrossEntropyLoss( ignore_index=label_ignore_index, reduction="none", compile=compile_loss, @@ -314,11 +311,8 @@ def __init__( self.model.apply_tp( tp_mesh, float8_enabled=float8_enabled, - loss_parallel=True, + loss_parallel=False, ) - self._train_loss_fn.apply_tp(tp_mesh) - # TODO: parallel eval loss? The tricky part is we don't reduce it. - tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" @@ -544,9 +538,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Get loss to optimize for, and the separate detached CE and Z loss values. loss, ce_loss, z_loss = self.loss_fn( - logits[..., :-1, :].contiguous(), - micro_batch["labels"], - batch_num_tokens_for_loss, + logits, micro_batch["labels"], batch_num_tokens_for_loss ) del logits @@ -610,7 +602,7 @@ def eval_batch( logits = self.model_forward(batch) loss: Optional[torch.Tensor] = None if labels is not None: - loss = self.eval_loss_fn(logits[..., :-1, :].contiguous(), labels) + loss = self.eval_loss_fn(logits, labels) return logits, loss diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 08da4524a..129329172 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -33,7 +33,7 @@ from olmo_core.distributed.utils import get_local_tensor, get_world_size from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler -from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss +from olmo_core.nn.cross_entropy_loss import LMCrossEntropyLoss from olmo_core.nn.transformer import NormalizedTransformer, Transformer from olmo_core.optim import OptimConfig, SkipStepOptimizer from olmo_core.optim.scheduler import Scheduler @@ -306,9 +306,6 @@ def __init__( super().__init__() # Validate some options. 
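The divisibility check that remains just below treats rank_microbatch_size as a token count, so it must describe a whole number of sequences. A small illustration with hypothetical numbers:

rank_microbatch_size = 4 * 4096   # tokens per rank per micro-batch
max_sequence_length = 4096
assert rank_microbatch_size % max_sequence_length == 0
sequences_per_micro_batch = rank_microbatch_size // max_sequence_length  # 4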
- if fused_loss and compile_loss: - raise OLMoConfigurationError("'fused_loss' is not compatible with 'compile_loss'") - if rank_microbatch_size % max_sequence_length != 0: raise OLMoConfigurationError( f"'rank_microbatch_size' ({rank_microbatch_size:,d} tokens) must be divisible by " @@ -322,14 +319,14 @@ def __init__( log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") self.label_ignore_index = label_ignore_index - self._train_loss_fn = CrossEntropyLoss( + self._train_loss_fn = LMCrossEntropyLoss( ignore_index=label_ignore_index, reduction="sum", z_loss_multiplier=z_loss_multiplier, compile=compile_loss, fused=fused_loss, ) - self._eval_loss_fn = CrossEntropyLoss( + self._eval_loss_fn = LMCrossEntropyLoss( ignore_index=label_ignore_index, reduction="none", compile=compile_loss, @@ -371,10 +368,8 @@ def __init__( model.apply_tp( tp_mesh, float8_enabled=float8_enabled, - loss_parallel=True, + loss_parallel=False, ) - self._train_loss_fn.apply_tp(tp_mesh) - # TODO: parallel eval loss? The tricky part is we don't reduce it. tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" @@ -497,7 +492,7 @@ def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' # (the total number of tokens used in the loss across the whole batch, not just the micro batch) # to avoid biasing the loss in the case where micro-batches might not be the same size. - ce_loss, z_loss = self._train_loss_fn(logits[..., :-1, :].contiguous(), labels) + ce_loss, z_loss = self._train_loss_fn(logits, labels) ce_loss.div_(self._batch_num_tokens_for_loss) if z_loss is not None: @@ -522,7 +517,7 @@ def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: return loss def eval_loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - ce_loss, _ = self._eval_loss_fn(logits[..., :-1, :].contiguous(), labels) + ce_loss, _ = self._eval_loss_fn(logits, labels) return ce_loss.view(logits.shape[0], -1) def on_attach(self): diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 9fdf05545..f7084af14 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -2,57 +2,10 @@ import pytest import torch -from torch.distributed import DeviceMesh, init_device_mesh -from torch.distributed.tensor import Placement, Replicate, Shard, distribute_tensor -from olmo_core.distributed.utils import get_world_size -from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss -from olmo_core.utils import get_default_device +from olmo_core.nn.cross_entropy_loss import LMCrossEntropyLoss -from ..distributed.utils import requires_multi_gpu, run_distributed_test - - -def run_cross_entropy_loss_parallel( - fused: bool, - compile: bool, - reduction: Literal["sum", "mean", "none"], - z_loss_multiplier: Optional[float], - logits: torch.Tensor, - labels: torch.Tensor, - batch_num_tokens_for_loss: torch.Tensor, -): - tp_mesh = init_device_mesh("cuda", (get_world_size(),), mesh_dim_names=("tp",)) - - logits = distribute_tensor( - logits.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) - ) - labels = distribute_tensor( - labels.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) - ) - batch_num_tokens_for_loss = batch_num_tokens_for_loss.to(device=get_default_device()) - - loss_fn = 
CrossEntropyLoss( - reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier - ) - loss_fn.apply_tp(tp_mesh) - - ce_loss, z_loss = loss_fn(logits[..., :-1, :].contiguous(), labels) - ce_loss.div_(batch_num_tokens_for_loss) - if z_loss is not None: - z_loss.div_(batch_num_tokens_for_loss) - - loss = ce_loss - if z_loss is not None: - loss += z_loss - - if reduction != "none": - assert loss.shape == tuple() - else: - assert loss.shape == labels.shape - - # Trigger backward pass. - loss.backward() - assert logits.grad is not None +from ..utils import requires_gpu @pytest.mark.parametrize( @@ -61,7 +14,7 @@ def run_cross_entropy_loss_parallel( pytest.param(False, False, "sum", id="default-sum"), ], ) -@requires_multi_gpu +@requires_gpu def test_cross_entropy_loss_parallel( fused: bool, compile: bool, @@ -70,7 +23,7 @@ def test_cross_entropy_loss_parallel( ): B, S, V = 4, 16, 256 - loss_fn = CrossEntropyLoss( + loss_fn = LMCrossEntropyLoss( reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier ) @@ -82,7 +35,7 @@ def test_cross_entropy_loss_parallel( labels[3][12] = -100 batch_num_tokens_for_loss = (labels != -100).sum() - ce_loss, z_loss = loss_fn(logits[..., :-1, :].contiguous(), labels) + ce_loss, z_loss = loss_fn(logits, labels) ce_loss.div_(batch_num_tokens_for_loss) if z_loss is not None: z_loss.div_(batch_num_tokens_for_loss) @@ -99,19 +52,3 @@ def test_cross_entropy_loss_parallel( # Trigger backward pass. loss.backward() assert logits.grad is not None - - run_distributed_test( - run_cross_entropy_loss_parallel, - world_size=2, - backend="nccl", - start_method="spawn", - func_args=( - fused, - compile, - reduction, - z_loss_multiplier, - logits.detach().cpu(), - labels.detach().cpu(), - batch_num_tokens_for_loss.detach().cpu(), - ), - ) From 3cab69ce34dd7a2025dea7d9d3eb77e44611f54d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 13:56:31 -0800 Subject: [PATCH 188/230] add case for none reduction --- src/olmo_core/nn/transformer/model.py | 4 ++-- src/test/nn/cross_entropy_loss_test.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 562fb0bb8..73e54670f 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -296,7 +296,7 @@ def apply_tp( ), ) - # Apply tensor + sequence parallelism to every transformer block. + # Apply tensor/sequence parallelism to every transformer block. # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. 
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437 @@ -307,7 +307,7 @@ def apply_tp( if self.lm_head is not None: self.lm_head.apply_tp( tp_mesh, - output_layout=Shard(1) if loss_parallel else Replicate(), + output_layout=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, ) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index f7084af14..36efd8a48 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -12,6 +12,7 @@ "fused, compile, reduction", [ pytest.param(False, False, "sum", id="default-sum"), + pytest.param(False, False, "none", id="default-none"), ], ) @requires_gpu @@ -48,6 +49,7 @@ def test_cross_entropy_loss_parallel( assert loss.shape == tuple() else: assert loss.shape == labels.shape + loss = loss.sum() # Trigger backward pass. loss.backward() From 30245a405e578243f884a02362c27229d967dbf2 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:16:23 -0800 Subject: [PATCH 189/230] ok try again --- src/olmo_core/data/utils.py | 4 +- src/olmo_core/nn/cross_entropy_loss.py | 182 ++++++++++++++---- src/olmo_core/train/common.py | 10 +- .../train/train_module/transformer.py | 6 +- .../train_module/transformer_pipeline.py | 6 +- src/test/nn/cross_entropy_loss_test.py | 77 +++++++- 6 files changed, 226 insertions(+), 59 deletions(-) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index dd44d3afd..806eaa71c 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torch.nn.functional as F from olmo_core.aliases import PathOrStr from olmo_core.io import add_cached_path_clients, get_bytes_range, is_url, resource_path @@ -467,4 +468,5 @@ def get_labels(batch: Dict[str, Any], label_ignore_index: int = -100) -> torch.T labels.masked_fill_(attention_mask == 0.0, label_ignore_index) if instance_mask is not None: labels.masked_fill_(~instance_mask.unsqueeze(-1), value=label_ignore_index) - return labels[..., 1:].contiguous() + # Shift and pad. 
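The padded shift implemented just below keeps the labels at the full sequence length: targets move one position to the left and the final slot is filled with the ignore index, so it never contributes to the loss. A rough sketch with a made-up input (padding only the last dimension, which is what the (0, 1, 0, 0) spec below amounts to):

import torch
import torch.nn.functional as F

label_ignore_index = -100
input_ids = torch.tensor([[10, 11, 12, 13]])  # shape (1, 4)
labels = F.pad(input_ids[..., 1:], (0, 1), value=label_ignore_index)
print(labels)  # tensor([[  11,   12,   13, -100]]) -- same length as input_ids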
+ return F.pad(labels[..., 1:], (0, 1, 0, 0), value=label_ignore_index) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index b2223e1f2..6257924d6 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -3,15 +3,76 @@ import torch import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.tensor import Placement, Replicate, Shard +from torch.distributed.tensor.parallel import ( + PrepareModuleInput, + PrepareModuleOutput, + parallelize_module, +) + +from olmo_core.distributed.utils import get_local_tensor from .functional import cross_entropy_loss, fused_cross_entropy_loss log = logging.getLogger(__name__) -class LMCrossEntropyLoss(nn.Module): +class _InnerCELoss(nn.Module): + def __init__( + self, + ignore_index: int = -100, + reduction: Literal["mean", "sum", "none"] = "mean", + z_loss_multiplier: Optional[float] = None, + fused: bool = False, + ): + super().__init__() + self.ignore_index = ignore_index + self.reduction: Literal["mean", "sum", "none"] = reduction + self.z_loss_multiplier = z_loss_multiplier + self.base_loss_fn = fused_cross_entropy_loss if fused else cross_entropy_loss + self.tp_enabled = False + + def forward( + self, + logits: torch.Tensor, + labels: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if logits.shape[:-1] != labels.shape: + raise RuntimeError( + f"expected labels to have shape {logits.shape[:-1]}, but found {tuple(labels.shape)} instead" + ) + + # shape: (B * S, V) + logits_for_loss = logits.view(-1, logits.shape[-1]) + + # shape: (B, S) -> (B * S,) + labels_for_loss = labels.view(-1) + + ce_loss, z_loss = self.base_loss_fn( + logits_for_loss, + labels_for_loss, + ignore_index=self.ignore_index, + reduction=self.reduction, + compute_z_loss=self.z_loss_multiplier is not None, + z_loss_multiplier=self.z_loss_multiplier or 1e-4, + ) + + if self.reduction == "none": + ce_loss = ce_loss.view(labels.shape) + if z_loss is not None: + z_loss = z_loss.view(labels.shape) + elif self.tp_enabled: + ce_loss = ce_loss.unsqueeze(0) + if z_loss is not None: + z_loss = z_loss.unsqueeze(0) + + return ce_loss, z_loss + + +class CrossEntropyLoss(nn.Module): """ - Cross-entropy loss for language modeling. + Cross-entropy loss. """ def __init__( @@ -28,10 +89,14 @@ def __init__( if compile and fused: log.warning(f"{self.__class__.__name__} with fused+compile is experimental") - self.ignore_index = ignore_index - self.reduction: Literal["mean", "sum", "none"] = reduction - self.z_loss_multiplier = z_loss_multiplier - self.base_loss_fn = fused_cross_entropy_loss if fused else cross_entropy_loss + self._ce_loss = _InnerCELoss( + ignore_index=ignore_index, + reduction=reduction, + z_loss_multiplier=z_loss_multiplier, + fused=fused, + ) + self._tp_enabled = False + if compile: if torch.cuda.is_available(): log.info("Compiling loss function...") @@ -39,6 +104,18 @@ def __init__( else: log.warning("Skipping loss compilation since CUDA is not available") + @property + def tp_enabled(self) -> bool: + return self._tp_enabled + + @property + def z_loss_enabled(self) -> bool: + return self._ce_loss.z_loss_multiplier is not None + + @property + def reduction(self) -> Literal["sum", "mean", "none"]: + return self._ce_loss.reduction + def forward( self, logits: torch.Tensor, @@ -48,42 +125,71 @@ def forward( Compute the CE loss and optionally Z-loss. :param logits: The logits of shape ``(B, S, V)``. 
- The final logits in the sequence dimension will be discarded to match the shape of - the labels. - :param labels: The target labels of shape ``(B, S-1)``. + :param labels: The target labels of shape ``(B, S)``. """ - if len(logits.shape) != 3: - raise RuntimeError( - f"expected logits to have shape (B, S, V) but found {tuple(logits.shape)} instead" - ) - - B, S, V = logits.shape - if labels.shape != (B, S - 1): - raise RuntimeError( - f"expected labels to have shape (B, S-1) = {(B, S-1)}, but found {tuple(labels.shape)} instead" - ) - - # shape: (B, S - 1, V) - logits_for_loss = logits[..., :-1, :].contiguous() + ce_loss, z_loss = self._ce_loss(get_local_tensor(logits), get_local_tensor(labels)) + + if self.reduction != "none" and ce_loss.numel() > 0: + # This will be the same case with tensor/sequence parallel loss. + if self.reduction == "sum": + ce_loss = ce_loss.sum() + if z_loss is not None: + z_loss = z_loss.sum() + elif self.reduction == "mean": + ce_loss = ce_loss.mean() + if z_loss is not None: + z_loss = z_loss.mean() + else: + raise NotImplementedError(self.reduction) - # shape: (B * (S - 1), V) - logits_for_loss = logits_for_loss.view(-1, V) + return ce_loss, z_loss - # shape: (B, S - 1) -> (B * (S - 1),) - labels_for_loss = labels.view(-1) + def apply_tp( + self, + tp_mesh: DeviceMesh, + input_layouts: Optional[Tuple[Placement, Placement]] = None, + shard_dimension: int = 1, + output_layout: Optional[Placement] = None, + use_local_output: bool = False, + ): + if self.reduction == "none": + raise NotImplementedError(self.reduction) + + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleInput( + input_layouts=input_layouts, # type: ignore + desired_input_layouts=(Shard(shard_dimension), Shard(shard_dimension)), # type: ignore + use_local_output=False, + ), + ) - ce_loss, z_loss = self.base_loss_fn( - logits_for_loss, - labels_for_loss, - ignore_index=self.ignore_index, - reduction=self.reduction, - compute_z_loss=self.z_loss_multiplier is not None, - z_loss_multiplier=self.z_loss_multiplier or 1e-4, + expected_output_layout = Shard(shard_dimension) if self.reduction == "none" else Shard(0) + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleOutput( + output_layouts=( # type: ignore + expected_output_layout, + None if self.z_loss_enabled is None else expected_output_layout, + ), + use_local_output=False, + ), ) - if self.reduction == "none": - ce_loss = ce_loss.view(labels.shape) - if z_loss is not None: - z_loss = z_loss.view(labels.shape) + desired_output_layout = output_layout or Replicate() + parallelize_module( + self, + device_mesh=tp_mesh, + parallelize_plan=PrepareModuleOutput( + desired_output_layouts=( # type: ignore + desired_output_layout, + None if self.z_loss_enabled is None else desired_output_layout, + ), + use_local_output=use_local_output, + ), + ) - return ce_loss, z_loss + self._tp_enabled = True + self._ce_loss.tp_enabled = True diff --git a/src/olmo_core/train/common.py b/src/olmo_core/train/common.py index e9c20c4c1..368f90ea2 100644 --- a/src/olmo_core/train/common.py +++ b/src/olmo_core/train/common.py @@ -125,14 +125,10 @@ class ReduceType(StrEnum): def reshape_inputs_for_loss( logits: torch.Tensor, labels: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - # shape: (batch_size, seq_len - 1, vocab_size) - logits_for_loss = logits[..., :-1, :].contiguous() - # shape: (batch_size * (seq_len - 1), vocab_size) - logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1)) - - # 
shape: (batch_size, seq_len - 1) -> (batch_size * (seq_len - 1),) + # shape: (B * S, V) + logits_for_loss = logits.view(-1, logits.size(-1)) + # shape: (B, S) -> (B * S,) labels_for_loss = labels.view(-1) - return logits_for_loss, labels_for_loss diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index dff5cf051..4d13cd4e3 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -31,7 +31,7 @@ from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler -from olmo_core.nn.cross_entropy_loss import LMCrossEntropyLoss +from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss from olmo_core.nn.transformer import ( MoETransformer, NormalizedTransformer, @@ -273,14 +273,14 @@ def __init__( log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") self.label_ignore_index = label_ignore_index - self._train_loss_fn = LMCrossEntropyLoss( + self._train_loss_fn = CrossEntropyLoss( ignore_index=label_ignore_index, reduction="sum", z_loss_multiplier=z_loss_multiplier, compile=compile_loss, fused=fused_loss, ) - self._eval_loss_fn = LMCrossEntropyLoss( + self._eval_loss_fn = CrossEntropyLoss( ignore_index=label_ignore_index, reduction="none", compile=compile_loss, diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 129329172..51f98e64a 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -33,7 +33,7 @@ from olmo_core.distributed.utils import get_local_tensor, get_world_size from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler -from olmo_core.nn.cross_entropy_loss import LMCrossEntropyLoss +from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss from olmo_core.nn.transformer import NormalizedTransformer, Transformer from olmo_core.optim import OptimConfig, SkipStepOptimizer from olmo_core.optim.scheduler import Scheduler @@ -319,14 +319,14 @@ def __init__( log.info(f"Data parallel world size = {get_world_size(self.dp_process_group):,d}") self.label_ignore_index = label_ignore_index - self._train_loss_fn = LMCrossEntropyLoss( + self._train_loss_fn = CrossEntropyLoss( ignore_index=label_ignore_index, reduction="sum", z_loss_multiplier=z_loss_multiplier, compile=compile_loss, fused=fused_loss, ) - self._eval_loss_fn = LMCrossEntropyLoss( + self._eval_loss_fn = CrossEntropyLoss( ignore_index=label_ignore_index, reduction="none", compile=compile_loss, diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 36efd8a48..8ca82a872 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -2,10 +2,57 @@ import pytest import torch +from torch.distributed import init_device_mesh +from torch.distributed.tensor import Shard, distribute_tensor -from olmo_core.nn.cross_entropy_loss import LMCrossEntropyLoss +from olmo_core.distributed.utils import get_world_size +from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss +from olmo_core.utils import get_default_device -from ..utils import requires_gpu +from ..distributed.utils import requires_multi_gpu, run_distributed_test + + +def run_cross_entropy_loss_parallel( + fused: bool, + compile: bool, + reduction: 
Literal["sum", "mean", "none"], + z_loss_multiplier: Optional[float], + logits: torch.Tensor, + labels: torch.Tensor, + batch_num_tokens_for_loss: torch.Tensor, +): + tp_mesh = init_device_mesh("cuda", (get_world_size(),), mesh_dim_names=("tp",)) + + logits = distribute_tensor( + logits.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) + ) + labels = distribute_tensor( + labels.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) + ) + batch_num_tokens_for_loss = batch_num_tokens_for_loss.to(device=get_default_device()) + + loss_fn = CrossEntropyLoss( + reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier + ) + loss_fn.apply_tp(tp_mesh) + + ce_loss, z_loss = loss_fn(logits[..., :-1, :].contiguous(), labels) + ce_loss.div_(batch_num_tokens_for_loss) + if z_loss is not None: + z_loss.div_(batch_num_tokens_for_loss) + + loss = ce_loss + if z_loss is not None: + loss += z_loss + + if reduction != "none": + assert loss.shape == tuple() + else: + assert loss.shape == labels.shape + + # Trigger backward pass. + loss.backward() + assert logits.grad is not None @pytest.mark.parametrize( @@ -15,7 +62,7 @@ pytest.param(False, False, "none", id="default-none"), ], ) -@requires_gpu +@requires_multi_gpu def test_cross_entropy_loss_parallel( fused: bool, compile: bool, @@ -24,18 +71,18 @@ def test_cross_entropy_loss_parallel( ): B, S, V = 4, 16, 256 - loss_fn = LMCrossEntropyLoss( + loss_fn = CrossEntropyLoss( reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier ) - input_ids = torch.randint(0, V, (B, S), device="cuda") + labels = torch.randint(0, V, (B, S), device="cuda") logits = torch.randn(B, S, V, device="cuda", requires_grad=True) - labels = input_ids.clone()[..., 1:].contiguous() labels[0][2] = -100 labels[2][9] = -100 labels[3][12] = -100 - batch_num_tokens_for_loss = (labels != -100).sum() + + # Get losses. ce_loss, z_loss = loss_fn(logits, labels) ce_loss.div_(batch_num_tokens_for_loss) if z_loss is not None: @@ -54,3 +101,19 @@ def test_cross_entropy_loss_parallel( # Trigger backward pass. 
loss.backward() assert logits.grad is not None + + run_distributed_test( + run_cross_entropy_loss_parallel, + world_size=2, + backend="nccl", + start_method="spawn", + func_args=( + fused, + compile, + reduction, + z_loss_multiplier, + logits.detach().cpu(), + labels.detach().cpu(), + batch_num_tokens_for_loss.detach().cpu(), + ), + ) From cfaf05705f1b69d6642a70700a2cef03e2190269 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:18:19 -0800 Subject: [PATCH 190/230] fix --- src/olmo_core/nn/cross_entropy_loss.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 6257924d6..20601f585 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -152,9 +152,6 @@ def apply_tp( output_layout: Optional[Placement] = None, use_local_output: bool = False, ): - if self.reduction == "none": - raise NotImplementedError(self.reduction) - parallelize_module( self, device_mesh=tp_mesh, @@ -174,6 +171,10 @@ def apply_tp( expected_output_layout, None if self.z_loss_enabled is None else expected_output_layout, ), + desired_output_layouts=( # type: ignore + expected_output_layout, + None if self.z_loss_enabled is None else expected_output_layout, + ), use_local_output=False, ), ) From 761c258e6bce55cf36b1b3cfff4a44382b23f2f4 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:21:01 -0800 Subject: [PATCH 191/230] try this --- src/olmo_core/nn/cross_entropy_loss.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 20601f585..f99e4d3f9 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -162,28 +162,33 @@ def apply_tp( ), ) - expected_output_layout = Shard(shard_dimension) if self.reduction == "none" else Shard(0) + inner_output_layout = Shard(shard_dimension) if self.reduction == "none" else Shard(0) parallelize_module( self, device_mesh=tp_mesh, parallelize_plan=PrepareModuleOutput( output_layouts=( # type: ignore - expected_output_layout, - None if self.z_loss_enabled is None else expected_output_layout, + inner_output_layout, + None if self.z_loss_enabled is None else inner_output_layout, ), desired_output_layouts=( # type: ignore - expected_output_layout, - None if self.z_loss_enabled is None else expected_output_layout, + inner_output_layout, + None if self.z_loss_enabled is None else inner_output_layout, ), use_local_output=False, ), ) + expected_output_layout = Shard(shard_dimension) if self.reduction == "none" else Replicate() desired_output_layout = output_layout or Replicate() parallelize_module( self, device_mesh=tp_mesh, parallelize_plan=PrepareModuleOutput( + output_layouts=( # type: ignore + expected_output_layout, + None if self.z_loss_enabled is None else expected_output_layout, + ), desired_output_layouts=( # type: ignore desired_output_layout, None if self.z_loss_enabled is None else desired_output_layout, From ae8b27dc41787a579284b617274d2e82c40d6a40 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:23:00 -0800 Subject: [PATCH 192/230] fix? 
--- src/test/nn/cross_entropy_loss_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 8ca82a872..4a59d78db 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -21,8 +21,10 @@ def run_cross_entropy_loss_parallel( labels: torch.Tensor, batch_num_tokens_for_loss: torch.Tensor, ): + # Init device mesh. tp_mesh = init_device_mesh("cuda", (get_world_size(),), mesh_dim_names=("tp",)) + # Put tensors on target device and potentially distributed over the device mesh . logits = distribute_tensor( logits.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) ) @@ -31,12 +33,14 @@ def run_cross_entropy_loss_parallel( ) batch_num_tokens_for_loss = batch_num_tokens_for_loss.to(device=get_default_device()) + # Initialize loss and apply parallelism. loss_fn = CrossEntropyLoss( reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier ) loss_fn.apply_tp(tp_mesh) - ce_loss, z_loss = loss_fn(logits[..., :-1, :].contiguous(), labels) + # Get loss tensors. + ce_loss, z_loss = loss_fn(logits, labels) ce_loss.div_(batch_num_tokens_for_loss) if z_loss is not None: z_loss.div_(batch_num_tokens_for_loss) From 75eee3a7166023937a80addf5297f128d78c4dae Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:24:30 -0800 Subject: [PATCH 193/230] fix --- src/olmo_core/nn/cross_entropy_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index f99e4d3f9..ccf9ae08f 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -169,11 +169,11 @@ def apply_tp( parallelize_plan=PrepareModuleOutput( output_layouts=( # type: ignore inner_output_layout, - None if self.z_loss_enabled is None else inner_output_layout, + None if not self.z_loss_enabled else inner_output_layout, ), desired_output_layouts=( # type: ignore inner_output_layout, - None if self.z_loss_enabled is None else inner_output_layout, + None if not self.z_loss_enabled else inner_output_layout, ), use_local_output=False, ), @@ -187,11 +187,11 @@ def apply_tp( parallelize_plan=PrepareModuleOutput( output_layouts=( # type: ignore expected_output_layout, - None if self.z_loss_enabled is None else expected_output_layout, + None if not self.z_loss_enabled else expected_output_layout, ), desired_output_layouts=( # type: ignore desired_output_layout, - None if self.z_loss_enabled is None else desired_output_layout, + None if not self.z_loss_enabled else desired_output_layout, ), use_local_output=use_local_output, ), From a035ee2754ca264fde0a69cae8afc5ad3e523e88 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:25:22 -0800 Subject: [PATCH 194/230] fix --- src/test/nn/cross_entropy_loss_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 4a59d78db..a8f87e8c3 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -26,7 +26,9 @@ def run_cross_entropy_loss_parallel( # Put tensors on target device and potentially distributed over the device mesh . 
logits = distribute_tensor( - logits.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) + logits.to(device=get_default_device()).requires_grad_(), + device_mesh=tp_mesh, + placements=(Shard(1),), ) labels = distribute_tensor( labels.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) From b1e2db470af85ccd60f68c13315d645ad8bfc6c7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:26:56 -0800 Subject: [PATCH 195/230] fix --- src/test/nn/cross_entropy_loss_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index a8f87e8c3..320f800b3 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -55,6 +55,7 @@ def run_cross_entropy_loss_parallel( assert loss.shape == tuple() else: assert loss.shape == labels.shape + loss = loss.sum() # Trigger backward pass. loss.backward() From 48284141ca9e66b962581bd99ff89a1d7e8cac9b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:30:15 -0800 Subject: [PATCH 196/230] fix --- src/olmo_core/nn/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index ccf9ae08f..914696f27 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -129,7 +129,7 @@ def forward( """ ce_loss, z_loss = self._ce_loss(get_local_tensor(logits), get_local_tensor(labels)) - if self.reduction != "none" and ce_loss.numel() > 0: + if self.reduction != "none" and ce_loss.numel() > 1: # This will be the same case with tensor/sequence parallel loss. if self.reduction == "sum": ce_loss = ce_loss.sum() From 524876cfee68962c9a8306592222f1a1a462e0bd Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:32:14 -0800 Subject: [PATCH 197/230] debug --- src/test/nn/cross_entropy_loss_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 320f800b3..26d850fc8 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -52,7 +52,7 @@ def run_cross_entropy_loss_parallel( loss += z_loss if reduction != "none": - assert loss.shape == tuple() + assert loss.shape == tuple(), f"{loss}" else: assert loss.shape == labels.shape loss = loss.sum() From 00a7b6eb439319a5232174ca9fcf5e16df966a33 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:33:58 -0800 Subject: [PATCH 198/230] fix? 
--- src/olmo_core/nn/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 914696f27..d47c1cd5e 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -164,7 +164,7 @@ def apply_tp( inner_output_layout = Shard(shard_dimension) if self.reduction == "none" else Shard(0) parallelize_module( - self, + self._ce_loss, device_mesh=tp_mesh, parallelize_plan=PrepareModuleOutput( output_layouts=( # type: ignore From 7509318af2e457498df4bd3aa4e4dc604ab9acdd Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:35:56 -0800 Subject: [PATCH 199/230] check for gradients --- src/test/nn/cross_entropy_loss_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 26d850fc8..af777935b 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -20,6 +20,7 @@ def run_cross_entropy_loss_parallel( logits: torch.Tensor, labels: torch.Tensor, batch_num_tokens_for_loss: torch.Tensor, + grad: torch.Tensor, ): # Init device mesh. tp_mesh = init_device_mesh("cuda", (get_world_size(),), mesh_dim_names=("tp",)) @@ -34,6 +35,9 @@ def run_cross_entropy_loss_parallel( labels.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) ) batch_num_tokens_for_loss = batch_num_tokens_for_loss.to(device=get_default_device()) + grad = distribute_tensor( + grad.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) + ) # Initialize loss and apply parallelism. loss_fn = CrossEntropyLoss( @@ -61,6 +65,9 @@ def run_cross_entropy_loss_parallel( loss.backward() assert logits.grad is not None + # Check gradients. 
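The assertion added just below compares the gradients from the tensor-parallel run against reference gradients computed without parallelism. A single-process sketch of the same pattern, using two equivalent formulations of the loss (illustrative only; torch.nn.functional.cross_entropy stands in for the project's loss functions):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.randint(0, 8, (2, 4))

F.cross_entropy(logits.view(-1, 8), labels.view(-1), reduction="sum").backward()
reference_grad = logits.grad.clone()

logits.grad = None
F.cross_entropy(logits.permute(0, 2, 1), labels, reduction="sum").backward()
torch.testing.assert_close(logits.grad, reference_grad)  # the two formulations agree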
+ torch.testing.assert_close(logits.grad, grad) + @pytest.mark.parametrize( "fused, compile, reduction", @@ -122,5 +129,6 @@ def test_cross_entropy_loss_parallel( logits.detach().cpu(), labels.detach().cpu(), batch_num_tokens_for_loss.detach().cpu(), + logits.grad.detach().cpu(), ), ) From 31e6615b275f9a7d449b12b1a34e2a78c491fda3 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:40:42 -0800 Subject: [PATCH 200/230] update train modules --- src/olmo_core/nn/transformer/model.py | 2 +- src/olmo_core/train/train_module/transformer.py | 4 +++- src/olmo_core/train/train_module/transformer_pipeline.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 73e54670f..db64d3b63 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -307,7 +307,7 @@ def apply_tp( if self.lm_head is not None: self.lm_head.apply_tp( tp_mesh, - output_layout=Shard(-1) if loss_parallel else Replicate(), + output_layout=Shard(1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, ) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 4d13cd4e3..084ef9c77 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -311,8 +311,10 @@ def __init__( self.model.apply_tp( tp_mesh, float8_enabled=float8_enabled, - loss_parallel=False, + loss_parallel=True, ) + self._train_loss_fn.apply_tp(tp_mesh) + self._eval_loss_fn.apply_tp(tp_mesh, use_local_output=True) tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 51f98e64a..9240ae730 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -370,6 +370,8 @@ def __init__( float8_enabled=float8_enabled, loss_parallel=False, ) + self._train_loss_fn.apply_tp(tp_mesh) + self._eval_loss_fn.apply_tp(tp_mesh, use_local_output=True) tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" From 55d60f58ea8a514e50dbbbf9edd72842837533f3 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:43:21 -0800 Subject: [PATCH 201/230] fix --- src/olmo_core/train/train_module/transformer.py | 2 +- src/olmo_core/train/train_module/transformer_pipeline.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 084ef9c77..5eec6f89a 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -520,7 +520,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): # Batch losses to record. 
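The block just below keeps running totals of the detached per-micro-batch losses so they can be reported once per step. A stripped-down sketch of that accumulation with hypothetical values:

import torch

ce_batch_loss = torch.tensor(0.0)
z_batch_loss = None  # only tracked when a z-loss multiplier is configured
for micro_batch_ce in (torch.tensor(1.50), torch.tensor(0.75)):  # detached micro-batch losses
    ce_batch_loss += micro_batch_ce
print(ce_batch_loss)  # tensor(2.2500)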
ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) z_batch_loss: Optional[torch.Tensor] = None - if self._train_loss_fn.z_loss_multiplier is not None: + if self._train_loss_fn.z_loss_enabled: z_batch_loss = move_to_device(torch.tensor(0.0), self.device) auxiliary_batch_losses: Dict[str, torch.Tensor] = {} diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 9240ae730..72b2680db 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -661,7 +661,7 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False): self.record_ce_loss( self._ce_batch_loss / get_world_size(self.dp_process_group), ReduceType.sum ) - if self._train_loss_fn.z_loss_multiplier is not None: + if self._train_loss_fn.z_loss_enabled: if self._z_batch_loss is None: self.record_metric("Z loss", 0.0, ReduceType.sum, namespace="train") else: From 8c9cac1bf7b5323b526656462e14e089648c5e7a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:46:16 -0800 Subject: [PATCH 202/230] fix? --- src/olmo_core/train/train_module/transformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 5eec6f89a..b27b4a2a5 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -9,7 +9,7 @@ import torch.nn as nn from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, Replicate, Shard from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -313,8 +313,10 @@ def __init__( float8_enabled=float8_enabled, loss_parallel=True, ) - self._train_loss_fn.apply_tp(tp_mesh) - self._eval_loss_fn.apply_tp(tp_mesh, use_local_output=True) + self._train_loss_fn.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate())) + self._eval_loss_fn.apply_tp( + tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + ) tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" From db7e42d1bc11cdd3c227ea43a1a7fd82b0964ffe Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:52:45 -0800 Subject: [PATCH 203/230] idk --- src/olmo_core/train/train_module/transformer.py | 4 +++- src/olmo_core/train/train_module/transformer_pipeline.py | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index b27b4a2a5..1622f6512 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -313,7 +313,9 @@ def __init__( float8_enabled=float8_enabled, loss_parallel=True, ) - self._train_loss_fn.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate())) + self._train_loss_fn.apply_tp( + tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + ) self._eval_loss_fn.apply_tp( tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True ) diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 72b2680db..16297e789 100644 --- 
a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -13,7 +13,7 @@ from torch.distributed import DeviceMesh from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.pipelining import PipelineStage -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, Replicate, Shard from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -370,8 +370,10 @@ def __init__( float8_enabled=float8_enabled, loss_parallel=False, ) - self._train_loss_fn.apply_tp(tp_mesh) - self._eval_loss_fn.apply_tp(tp_mesh, use_local_output=True) + self._train_loss_fn.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate())) + self._eval_loss_fn.apply_tp( + tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + ) tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" From 747a7ad2cb18a3cc246f1831a2696fafd8ad0fa0 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 17:55:43 -0800 Subject: [PATCH 204/230] make loss parallel configurable --- src/olmo_core/train/train_module/transformer.py | 17 ++++++++++------- .../train/train_module/transformer_pipeline.py | 13 ++++++++----- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 1622f6512..7a7b81e5e 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -69,6 +69,8 @@ class TransformerTensorParallelConfig(TensorParallelConfig): Transformer-specific tensor parallel config. """ + loss_parallel: bool = True + @dataclass class TransformerExpertParallelConfig(ExpertParallelConfig): @@ -311,14 +313,15 @@ def __init__( self.model.apply_tp( tp_mesh, float8_enabled=float8_enabled, - loss_parallel=True, - ) - self._train_loss_fn.apply_tp( - tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True - ) - self._eval_loss_fn.apply_tp( - tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + loss_parallel=tp_config.loss_parallel, ) + if tp_config.loss_parallel: + self._train_loss_fn.apply_tp( + tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + ) + self._eval_loss_fn.apply_tp( + tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + ) tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 16297e789..5b54db9fb 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -368,12 +368,15 @@ def __init__( model.apply_tp( tp_mesh, float8_enabled=float8_enabled, - loss_parallel=False, + loss_parallel=tp_config.loss_parallel, + ) + if tp_config.loss_parallel: + self._train_loss_fn.apply_tp( + tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + ) + self._eval_loss_fn.apply_tp( + tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True ) - self._train_loss_fn.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate())) - self._eval_loss_fn.apply_tp( - tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True - ) tp_config.maybe_enable_async_tp(tp_mesh) log.info( f"Applied 
{'Float8 ' if float8_enabled else ''}tensor parallelism to the model" From eeb310a42b5ae6162393f0e47f8a2a3946d69bd5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 18:44:08 -0800 Subject: [PATCH 205/230] add long context config --- src/olmo_core/internal/experiment.py | 20 +++- src/scripts/train/OLMo2-7B-long-context.py | 112 +++++++++++++++++++++ 2 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 src/scripts/train/OLMo2-7B-long-context.py diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 643aea9b6..0173ab359 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -136,6 +136,7 @@ def build_common_components( overrides: List[str], *, global_batch_size: int, + sequence_length: int = 4096, ) -> CommonComponents: root_dir = get_root_dir(cluster) @@ -159,10 +160,10 @@ def build_common_components( DataMix.OLMoE_mix_0824, tokenizer=tokenizer_config, mix_base_dir=root_dir, - sequence_length=4096, - max_target_sequence_length=8192, - min_sequence_length=256, - max_sequence_length=8192, + sequence_length=sequence_length, + max_target_sequence_length=max(8192, sequence_length), + min_sequence_length=min(256, sequence_length), + max_sequence_length=max(8192, sequence_length), vsl_curriculum=VSLCurriculumConfig( name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False ), @@ -221,9 +222,16 @@ def build_config( train_module_config_builder: Callable[[CommonComponents], TransformerTrainModuleConfig], trainer_config_builder: Callable[[CommonComponents], TrainerConfig], finalize_config: Optional[Callable[[ExperimentConfig], None]] = None, + sequence_length: int = 4096, ) -> ExperimentConfig: common = build_common_components( - script, cmd, run_name, cluster, overrides, global_batch_size=global_batch_size + script, + cmd, + run_name, + cluster, + overrides, + global_batch_size=global_batch_size, + sequence_length=sequence_length, ) model = model_config_builder(common) @@ -297,6 +305,7 @@ def main( train_module_config_builder: Callable[[CommonComponents], TransformerTrainModuleConfig], trainer_config_builder: Callable[[CommonComponents], TrainerConfig], finalize_config: Optional[Callable[[ExperimentConfig], None]] = None, + sequence_length: int = 4096, ): usage = f""" [yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]RUN_NAME CLUSTER[/] [i][OVERRIDES...][/] @@ -336,6 +345,7 @@ def main( train_module_config_builder=train_module_config_builder, trainer_config_builder=trainer_config_builder, finalize_config=finalize_config, + sequence_length=sequence_length, ) cmd.run(config) diff --git a/src/scripts/train/OLMo2-7B-long-context.py b/src/scripts/train/OLMo2-7B-long-context.py new file mode 100644 index 000000000..0a23eedf2 --- /dev/null +++ b/src/scripts/train/OLMo2-7B-long-context.py @@ -0,0 +1,112 @@ +""" +Train a 7B OLMo model on long contexts. Run this script without any arguments to see usage info. 
+""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel import DataParallelType +from olmo_core.float8 import Float8Config +from olmo_core.internal.experiment import CommonComponents, main +from olmo_core.nn.transformer import TransformerConfig +from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride +from olmo_core.train import TrainerConfig +from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback +from olmo_core.train.train_module import ( + TransformerActivationCheckpointingConfig, + TransformerDataParallelConfig, + TransformerDataParallelWrappingStrategy, + TransformerTensorParallelConfig, + TransformerTrainModuleConfig, +) + +log = logging.getLogger(__name__) + + +CONTEXT_LENGTH = 16_384 + + +def build_model_config(common: CommonComponents) -> TransformerConfig: + return TransformerConfig.olmo2_7B(vocab_size=common.tokenizer.padded_vocab_size()) + + +def build_train_module_config(common: CommonComponents) -> TransformerTrainModuleConfig: + return TransformerTrainModuleConfig( + rank_microbatch_size=1 * CONTEXT_LENGTH, + max_sequence_length=common.dataset.effective_sequence_length, + optim=AdamWConfig( + lr=3e-5, + weight_decay=0.1, + betas=(0.9, 0.95), + group_overrides=[ + OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) + ], + fused=True, + ), + compile_model=True, + dp_config=TransformerDataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + wrapping_strategy=TransformerDataParallelWrappingStrategy.fine_grained, + ), + tp_config=TransformerTensorParallelConfig( + degree=2, + loss_parallel=True, + ), + ac_config=TransformerActivationCheckpointingConfig(), + float8_config=Float8Config(enabled=True), + z_loss_multiplier=1e-5, + compile_loss=True, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=2000), + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + return ( + TrainerConfig( + save_folder=common.save_folder, + save_overwrite=True, + metrics_collect_interval=10, + cancel_check_interval=1, + ) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=10_000, + ephemeral_save_interval=250, + save_async=True, + ), + ) + .with_callback( + "comet", + CometCallback( + name=common.run_name, + workspace="ai2", + project="OLMo-core-7B", + enabled=True, + cancel_check_interval=10, + ), + ) + .with_callback( + "wandb", + WandBCallback( + name=common.run_name, + entity="ai2-llm", + project="OLMo-core-7B", + enabled=False, + cancel_check_interval=10, + ), + ) + ) + + +if __name__ == "__main__": + main( + sequence_length=CONTEXT_LENGTH, + global_batch_size=64 * CONTEXT_LENGTH, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + trainer_config_builder=build_trainer_config, + ) From 253786452f35d152365e359525f2bf798014935c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 18:54:44 -0800 Subject: [PATCH 206/230] increase context length --- src/scripts/train/OLMo2-7B-long-context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scripts/train/OLMo2-7B-long-context.py b/src/scripts/train/OLMo2-7B-long-context.py index 0a23eedf2..5ee05c767 100644 --- a/src/scripts/train/OLMo2-7B-long-context.py +++ b/src/scripts/train/OLMo2-7B-long-context.py @@ -23,7 +23,7 @@ log = logging.getLogger(__name__) -CONTEXT_LENGTH = 16_384 +CONTEXT_LENGTH = 2 * 16_384 def build_model_config(common: 
CommonComponents) -> TransformerConfig: From e8db8bada9c778795fe0cc57322c7b7596d036a2 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 8 Feb 2025 18:58:13 -0800 Subject: [PATCH 207/230] try 64 --- src/scripts/train/OLMo2-7B-long-context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scripts/train/OLMo2-7B-long-context.py b/src/scripts/train/OLMo2-7B-long-context.py index 5ee05c767..f2438d063 100644 --- a/src/scripts/train/OLMo2-7B-long-context.py +++ b/src/scripts/train/OLMo2-7B-long-context.py @@ -23,7 +23,7 @@ log = logging.getLogger(__name__) -CONTEXT_LENGTH = 2 * 16_384 +CONTEXT_LENGTH = 4 * 16_384 def build_model_config(common: CommonComponents) -> TransformerConfig: From ae584520ba7a8aebe47ad3ced70facfd6e03c622 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 9 Feb 2025 21:54:37 -0800 Subject: [PATCH 208/230] rename some things --- src/olmo_core/nn/cross_entropy_loss.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index d47c1cd5e..419d78b34 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -18,7 +18,7 @@ log = logging.getLogger(__name__) -class _InnerCELoss(nn.Module): +class _CELossFnWrapper(nn.Module): def __init__( self, ignore_index: int = -100, @@ -89,7 +89,7 @@ def __init__( if compile and fused: log.warning(f"{self.__class__.__name__} with fused+compile is experimental") - self._ce_loss = _InnerCELoss( + self.loss_fn = _CELossFnWrapper( ignore_index=ignore_index, reduction=reduction, z_loss_multiplier=z_loss_multiplier, @@ -110,11 +110,11 @@ def tp_enabled(self) -> bool: @property def z_loss_enabled(self) -> bool: - return self._ce_loss.z_loss_multiplier is not None + return self.loss_fn.z_loss_multiplier is not None @property def reduction(self) -> Literal["sum", "mean", "none"]: - return self._ce_loss.reduction + return self.loss_fn.reduction def forward( self, @@ -127,7 +127,7 @@ def forward( :param logits: The logits of shape ``(B, S, V)``. :param labels: The target labels of shape ``(B, S)``. """ - ce_loss, z_loss = self._ce_loss(get_local_tensor(logits), get_local_tensor(labels)) + ce_loss, z_loss = self.loss_fn(get_local_tensor(logits), get_local_tensor(labels)) if self.reduction != "none" and ce_loss.numel() > 1: # This will be the same case with tensor/sequence parallel loss. 
@@ -164,7 +164,7 @@ def apply_tp( inner_output_layout = Shard(shard_dimension) if self.reduction == "none" else Shard(0) parallelize_module( - self._ce_loss, + self.loss_fn, device_mesh=tp_mesh, parallelize_plan=PrepareModuleOutput( output_layouts=( # type: ignore @@ -198,4 +198,4 @@ def apply_tp( ) self._tp_enabled = True - self._ce_loss.tp_enabled = True + self.loss_fn.tp_enabled = True From 91679ef1bbb92c8dfcf882e6b9ee56cc606f65a9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 08:38:35 -0800 Subject: [PATCH 209/230] back to non moe --- src/examples/llama/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/examples/llama/train.py b/src/examples/llama/train.py index 171a51a40..d2736f876 100644 --- a/src/examples/llama/train.py +++ b/src/examples/llama/train.py @@ -56,8 +56,8 @@ class ExperimentConfig(Config): def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: tokenizer_config = TokenizerConfig.gpt2() - # model_config = TransformerConfig.llama2_271M( - model_config = TransformerConfig.smallmoe( + model_config = TransformerConfig.llama2_271M( + # model_config = TransformerConfig.smallmoe( vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128 ) From 9e426d54ac432e1b224dbff02394adb28bdb4901 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 08:46:27 -0800 Subject: [PATCH 210/230] fix? --- src/olmo_core/eval/lm_evaluator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/eval/lm_evaluator.py b/src/olmo_core/eval/lm_evaluator.py index d9d7ca66e..5663704c9 100644 --- a/src/olmo_core/eval/lm_evaluator.py +++ b/src/olmo_core/eval/lm_evaluator.py @@ -96,7 +96,7 @@ def update_metrics( for idx, (metadata, tokens_loss) in enumerate(zip(batch["metadata"], ce_loss)): metric = self.metrics[metadata["label"]] if "label_mask" in batch: - tokens_loss = tokens_loss.masked_select(batch["label_mask"][idx][1:]) + tokens_loss = tokens_loss.masked_select(batch["label_mask"][idx]) metric.update(tokens_loss) def compute_metrics(self) -> Dict[str, torch.Tensor]: From 5ea11ea477f08d573fa15ff7612463ceb30f7582 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 09:31:00 -0800 Subject: [PATCH 211/230] fix? 
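With loss parallelism enabled, the logits and loss coming out of `eval_batch` can be `DTensor`s sharded over the tensor-parallel mesh, which downstream evaluator callbacks don't expect; the `get_full_tensor` helper added below gathers a `DTensor` back into an ordinary tensor and passes plain tensors through unchanged. A minimal illustration of that behavior (the mesh setup and shapes are assumptions for this sketch, not part of the patch, and it needs a 2-process `torchrun` launch on CUDA):

    import torch
    from torch.distributed import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    # Two-way tensor-parallel mesh; run with `torchrun --nproc-per-node=2`.
    tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

    # A (2, 4) tensor sharded over its last dimension: each rank holds a (2, 2) shard.
    sharded = distribute_tensor(
        torch.arange(8.0, device="cuda").reshape(2, 4),
        device_mesh=tp_mesh,
        placements=(Shard(1),),
    )

    # `DTensor.full_tensor()` all-gathers the shards, giving every rank the same
    # ordinary (2, 4) torch.Tensor, which is what `get_full_tensor` relies on.
    full = sharded.full_tensor()
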
--- src/olmo_core/distributed/utils.py | 7 +++++++ src/olmo_core/train/train_module/transformer.py | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index fc3f2e101..b9ce8e99b 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -424,6 +424,13 @@ def get_local_tensor(x: torch.Tensor) -> torch.Tensor: return x +def get_full_tensor(x: torch.Tensor) -> torch.Tensor: + if isinstance(x, DTensor): + return x.full_tensor() + else: + return x + + def do_n_at_a_time( f: Callable[[], T], *, diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 7a7b81e5e..e5abdb2b5 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -27,7 +27,11 @@ get_ep_mesh, get_tp_mesh, ) -from olmo_core.distributed.utils import get_local_tensor, get_world_size +from olmo_core.distributed.utils import ( + get_full_tensor, + get_local_tensor, + get_world_size, +) from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler @@ -612,6 +616,8 @@ def eval_batch( loss: Optional[torch.Tensor] = None if labels is not None: loss = self.eval_loss_fn(logits, labels) + loss = get_full_tensor(loss) + logits = get_full_tensor(logits) return logits, loss From 92538ca2ced31f2f59ce53288378c343e0494221 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 09:37:38 -0800 Subject: [PATCH 212/230] fixed eval sequence length when TP enabled --- src/olmo_core/train/train_module/transformer.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index e5abdb2b5..e591ef00d 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -310,6 +310,7 @@ def __init__( log.info("Swapped linear layers to Float8 linear layers") # Maybe apply tensor/expert parallelism. + self._tp_enabled = False if tp_config is not None and ep_config is not None: raise NotImplementedError("TP + EP is not implemented yet") if tp_config is not None: @@ -330,12 +331,16 @@ def __init__( log.info( f"Applied {'Float8 ' if float8_enabled else ''}tensor parallelism to the model" ) + self._tp_enabled = True + + self._ep_enabled = False if ep_config is not None: if not self.model.is_moe: raise OLMoConfigurationError("Expert parallelism is only valid for MoE models") ep_mesh = get_ep_mesh(self.world_mesh) cast(MoETransformer, self.model).apply_ep(ep_mesh) log.info("Applied expert parallelism to the model") + self._ep_enabled = True # Maybe apply activation checkpointing. 
if ac_config is not None: @@ -407,9 +412,19 @@ def dp_process_group(self) -> Optional[dist.ProcessGroup]: @property def eval_batch_spec(self) -> EvalBatchSpec: return EvalBatchSpec( - self.rank_microbatch_size, max_sequence_length=self.max_sequence_length + self.rank_microbatch_size, + max_sequence_length=self.max_sequence_length, + fixed_sequence_length=self.tp_enabled, ) + @property + def tp_enabled(self) -> bool: + return self._tp_enabled + + @property + def ep_enabled(self) -> bool: + return self._ep_enabled + def loss_fn( self, logits: torch.Tensor, labels: torch.Tensor, batch_num_tokens_for_loss: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: From 2bfafde85f8aefeb9686cf35e6a8f6cb647ed38c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 09:54:16 -0800 Subject: [PATCH 213/230] fix eval batch size --- .../train/callbacks/evaluator_callback.py | 15 ++++++++++++++- .../train/train_module/transformer_pipeline.py | 8 +++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py index 55dba1f81..63d76cb56 100644 --- a/src/olmo_core/train/callbacks/evaluator_callback.py +++ b/src/olmo_core/train/callbacks/evaluator_callback.py @@ -250,7 +250,20 @@ def __init__( if batch_spec.batch_size_unit == EvalBatchSizeUnit.instances: rank_batch_size_instances = batch_spec.rank_batch_size elif batch_spec.batch_size_unit == EvalBatchSizeUnit.tokens: - rank_batch_size_instances = batch_spec.rank_batch_size // self.task.max_sequence_length + if batch_spec.fixed_sequence_length: + assert batch_spec.max_sequence_length is not None + if batch_spec.rank_batch_size % batch_spec.max_sequence_length != 0: + raise OLMoConfigurationError( + f"The eval batch size ({batch_spec.rank_batch_size} tokens) must be divisible " + f"by the maximum eval sequence length ({batch_spec.max_sequence_length:,d} tokens)" + ) + rank_batch_size_instances = ( + batch_spec.rank_batch_size // batch_spec.max_sequence_length + ) + else: + rank_batch_size_instances = ( + batch_spec.rank_batch_size // self.task.max_sequence_length + ) else: raise NotImplementedError(batch_spec.batch_size_unit) diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 5b54db9fb..8f2fa06e1 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -30,7 +30,11 @@ get_pp_mesh, get_tp_mesh, ) -from olmo_core.distributed.utils import get_local_tensor, get_world_size +from olmo_core.distributed.utils import ( + get_full_tensor, + get_local_tensor, + get_world_size, +) from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config, Float8Handler from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss @@ -701,6 +705,8 @@ def eval_batch( assert logits is not None if labels is not None: assert loss is not None + loss = get_full_tensor(loss) + logits = get_full_tensor(logits) self._clear_loss_buffers() From 2b2d838fb77f2350388b4ea9545450d75ef15779 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 10:28:16 -0800 Subject: [PATCH 214/230] revert the FSL --- src/olmo_core/train/train_module/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index e591ef00d..489761dd0 100644 --- 
a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -414,7 +414,7 @@ def eval_batch_spec(self) -> EvalBatchSpec: return EvalBatchSpec( self.rank_microbatch_size, max_sequence_length=self.max_sequence_length, - fixed_sequence_length=self.tp_enabled, + # fixed_sequence_length=self.tp_enabled, ) @property From 7b9488f37fd21c68a59e4b9b44d41777cb849119 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:20:59 -0800 Subject: [PATCH 215/230] upgrade olmo-eval --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 654087b77..eb27a09e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "omegaconf", "safetensors", "importlib_resources", - "ai2-olmo-eval==0.5.0", + "ai2-olmo-eval==0.6.1", ] [project.urls] From 0c7cadc911de66801bff6a2d5c85d8f65982341a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:25:02 -0800 Subject: [PATCH 216/230] update test --- src/test/nn/cross_entropy_loss_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index af777935b..0cc6fb782 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -21,6 +21,7 @@ def run_cross_entropy_loss_parallel( labels: torch.Tensor, batch_num_tokens_for_loss: torch.Tensor, grad: torch.Tensor, + loss: torch.Tensor, ): # Init device mesh. tp_mesh = init_device_mesh("cuda", (get_world_size(),), mesh_dim_names=("tp",)) @@ -38,12 +39,13 @@ def run_cross_entropy_loss_parallel( grad = distribute_tensor( grad.to(device=get_default_device()), device_mesh=tp_mesh, placements=(Shard(1),) ) + loss = loss.to(device=get_default_device()) # Initialize loss and apply parallelism. loss_fn = CrossEntropyLoss( reduction=reduction, compile=compile, fused=fused, z_loss_multiplier=z_loss_multiplier ) - loss_fn.apply_tp(tp_mesh) + loss_fn.apply_tp(tp_mesh, use_local_output=True) # Get loss tensors. ce_loss, z_loss = loss_fn(logits, labels) @@ -67,6 +69,7 @@ def run_cross_entropy_loss_parallel( # Check gradients. torch.testing.assert_close(logits.grad, grad) + torch.testing.assert_close(loss.detach(), loss) @pytest.mark.parametrize( @@ -130,5 +133,6 @@ def test_cross_entropy_loss_parallel( labels.detach().cpu(), batch_num_tokens_for_loss.detach().cpu(), logits.grad.detach().cpu(), + loss.detach().cpu(), ), ) From 95ca989d86b8e3acd28cc209369b8db34e40a65e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:31:27 -0800 Subject: [PATCH 217/230] fix? --- src/olmo_core/nn/cross_entropy_loss.py | 3 ++- src/test/nn/cross_entropy_loss_test.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 419d78b34..51d82c6fe 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -130,7 +130,8 @@ def forward( ce_loss, z_loss = self.loss_fn(get_local_tensor(logits), get_local_tensor(labels)) if self.reduction != "none" and ce_loss.numel() > 1: - # This will be the same case with tensor/sequence parallel loss. + # This will be the same case with tensor/sequence parallel loss where we have a DTensor. 
+ assert self.tp_enabled if self.reduction == "sum": ce_loss = ce_loss.sum() if z_loss is not None: diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 0cc6fb782..407f6c2ff 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -49,9 +49,9 @@ def run_cross_entropy_loss_parallel( # Get loss tensors. ce_loss, z_loss = loss_fn(logits, labels) - ce_loss.div_(batch_num_tokens_for_loss) + ce_loss = ce_loss.div(batch_num_tokens_for_loss) if z_loss is not None: - z_loss.div_(batch_num_tokens_for_loss) + z_loss = z_loss.div(batch_num_tokens_for_loss) loss = ce_loss if z_loss is not None: From be2b15b65f3fae42935412bc72492bc39c5e5843 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:34:24 -0800 Subject: [PATCH 218/230] check loss first --- src/test/nn/cross_entropy_loss_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 407f6c2ff..d2e01f095 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -63,13 +63,15 @@ def run_cross_entropy_loss_parallel( assert loss.shape == labels.shape loss = loss.sum() + # Check loss. + torch.testing.assert_close(loss.detach(), loss) + # Trigger backward pass. loss.backward() assert logits.grad is not None # Check gradients. torch.testing.assert_close(logits.grad, grad) - torch.testing.assert_close(loss.detach(), loss) @pytest.mark.parametrize( From cd818fdc86c1d4cc6fa951a79cb66f03d825804c Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:39:54 -0800 Subject: [PATCH 219/230] compare local tensor --- src/test/nn/cross_entropy_loss_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index d2e01f095..81e9e6c1b 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -5,7 +5,7 @@ from torch.distributed import init_device_mesh from torch.distributed.tensor import Shard, distribute_tensor -from olmo_core.distributed.utils import get_world_size +from olmo_core.distributed.utils import get_local_tensor, get_world_size from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss from olmo_core.utils import get_default_device @@ -71,7 +71,7 @@ def run_cross_entropy_loss_parallel( assert logits.grad is not None # Check gradients. 
- torch.testing.assert_close(logits.grad, grad) + torch.testing.assert_close(get_local_tensor(logits.grad), get_local_tensor(grad)) @pytest.mark.parametrize( From 0ec327c04ee9508b88ceee8a50ba8671f393fafb Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:48:32 -0800 Subject: [PATCH 220/230] clean up --- src/test/nn/cross_entropy_loss_test.py | 58 +++++++++++++------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index 81e9e6c1b..b0b488cec 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -12,6 +12,30 @@ from ..distributed.utils import requires_multi_gpu, run_distributed_test +def compute_loss( + loss_fn: CrossEntropyLoss, + logits: torch.Tensor, + labels: torch.Tensor, + batch_num_tokens_for_loss: torch.Tensor, +) -> torch.Tensor: + ce_loss, z_loss = loss_fn(logits, labels) + ce_loss = ce_loss / batch_num_tokens_for_loss + if z_loss is not None: + z_loss = z_loss / batch_num_tokens_for_loss + + loss = ce_loss + if z_loss is not None: + loss += z_loss + + if loss_fn.reduction != "none": + assert loss.shape == tuple(), f"{loss}" + else: + assert loss.shape == labels.shape + loss = loss.sum() + + return loss + + def run_cross_entropy_loss_parallel( fused: bool, compile: bool, @@ -47,21 +71,8 @@ def run_cross_entropy_loss_parallel( ) loss_fn.apply_tp(tp_mesh, use_local_output=True) - # Get loss tensors. - ce_loss, z_loss = loss_fn(logits, labels) - ce_loss = ce_loss.div(batch_num_tokens_for_loss) - if z_loss is not None: - z_loss = z_loss.div(batch_num_tokens_for_loss) - - loss = ce_loss - if z_loss is not None: - loss += z_loss - - if reduction != "none": - assert loss.shape == tuple(), f"{loss}" - else: - assert loss.shape == labels.shape - loss = loss.sum() + # Get loss. + loss = compute_loss(loss_fn, logits, labels, batch_num_tokens_for_loss) # Check loss. torch.testing.assert_close(loss.detach(), loss) @@ -101,21 +112,8 @@ def test_cross_entropy_loss_parallel( labels[3][12] = -100 batch_num_tokens_for_loss = (labels != -100).sum() - # Get losses. - ce_loss, z_loss = loss_fn(logits, labels) - ce_loss.div_(batch_num_tokens_for_loss) - if z_loss is not None: - z_loss.div_(batch_num_tokens_for_loss) - - loss = ce_loss - if z_loss is not None: - loss += z_loss - - if reduction != "none": - assert loss.shape == tuple() - else: - assert loss.shape == labels.shape - loss = loss.sum() + # Get loss. + loss = compute_loss(loss_fn, logits, labels, batch_num_tokens_for_loss) # Trigger backward pass. loss.backward() From d87a5d3e001ab8262ca201582c4c313c37987b6e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:55:53 -0800 Subject: [PATCH 221/230] try this --- src/olmo_core/nn/cross_entropy_loss.py | 10 ++++++++-- src/olmo_core/train/train_module/transformer.py | 8 ++++++-- src/test/nn/cross_entropy_loss_test.py | 8 ++++---- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 51d82c6fe..636a41c5a 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -120,6 +120,7 @@ def forward( self, logits: torch.Tensor, labels: torch.Tensor, + div_factor: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Compute the CE loss and optionally Z-loss. 
@@ -143,12 +144,17 @@ def forward( else: raise NotImplementedError(self.reduction) + if div_factor is not None: + ce_loss = ce_loss / div_factor + if z_loss is not None: + z_loss = z_loss / div_factor + return ce_loss, z_loss def apply_tp( self, tp_mesh: DeviceMesh, - input_layouts: Optional[Tuple[Placement, Placement]] = None, + input_layouts: Optional[Tuple[Placement, Placement, Placement]] = None, shard_dimension: int = 1, output_layout: Optional[Placement] = None, use_local_output: bool = False, @@ -158,7 +164,7 @@ def apply_tp( device_mesh=tp_mesh, parallelize_plan=PrepareModuleInput( input_layouts=input_layouts, # type: ignore - desired_input_layouts=(Shard(shard_dimension), Shard(shard_dimension)), # type: ignore + desired_input_layouts=(Shard(shard_dimension), Shard(shard_dimension), Replicate()), # type: ignore use_local_output=False, ), ) diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 489761dd0..370db8f1a 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -322,10 +322,14 @@ def __init__( ) if tp_config.loss_parallel: self._train_loss_fn.apply_tp( - tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + tp_mesh, + input_layouts=(Shard(1), Replicate(), Replicate()), + use_local_output=True, ) self._eval_loss_fn.apply_tp( - tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + tp_mesh, + input_layouts=(Shard(1), Replicate(), Replicate()), + use_local_output=True, ) tp_config.maybe_enable_async_tp(tp_mesh) log.info( diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index b0b488cec..e37d26f85 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -18,10 +18,10 @@ def compute_loss( labels: torch.Tensor, batch_num_tokens_for_loss: torch.Tensor, ) -> torch.Tensor: - ce_loss, z_loss = loss_fn(logits, labels) - ce_loss = ce_loss / batch_num_tokens_for_loss - if z_loss is not None: - z_loss = z_loss / batch_num_tokens_for_loss + ce_loss, z_loss = loss_fn(logits, labels, batch_num_tokens_for_loss) + # ce_loss = ce_loss / batch_num_tokens_for_loss + # if z_loss is not None: + # z_loss = z_loss / batch_num_tokens_for_loss loss = ce_loss if z_loss is not None: From 2ea5307d8d8bddab1367c1123eeab678500e5004 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 11:58:21 -0800 Subject: [PATCH 222/230] fix? 
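Together with the previous commit, this routes the batch-level token count through the loss module as an explicit, replicated input instead of dividing the result in place afterwards, which is what allows it to be declared in the tensor-parallel input plan. A single-process sketch of the resulting call pattern, mirroring `compute_loss` from the test above (the shapes and the 1e-5 Z-loss multiplier are illustrative assumptions):

    import torch

    from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss

    loss_fn = CrossEntropyLoss(
        reduction="sum", fused=False, compile=False, z_loss_multiplier=1e-5
    )

    logits = torch.randn(2, 16, 128, requires_grad=True)
    labels = torch.randint(0, 128, (2, 16))
    labels[0][2] = -100  # masked positions are excluded from the divisor
    batch_num_tokens_for_loss = (labels != -100).sum()

    # The divisor is now the optional third argument to the loss itself.
    ce_loss, z_loss = loss_fn(logits, labels, batch_num_tokens_for_loss)
    loss = ce_loss if z_loss is None else ce_loss + z_loss
    loss.backward()
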
--- src/olmo_core/nn/cross_entropy_loss.py | 2 +- src/olmo_core/train/train_module/transformer.py | 8 ++------ .../train/train_module/transformer_pipeline.py | 10 ++++------ src/test/nn/cross_entropy_loss_test.py | 3 --- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 636a41c5a..2929c736d 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -154,7 +154,7 @@ def forward( def apply_tp( self, tp_mesh: DeviceMesh, - input_layouts: Optional[Tuple[Placement, Placement, Placement]] = None, + input_layouts: Optional[Tuple[Placement, ...]] = None, shard_dimension: int = 1, output_layout: Optional[Placement] = None, use_local_output: bool = False, diff --git a/src/olmo_core/train/train_module/transformer.py b/src/olmo_core/train/train_module/transformer.py index 370db8f1a..4cadd43fa 100644 --- a/src/olmo_core/train/train_module/transformer.py +++ b/src/olmo_core/train/train_module/transformer.py @@ -328,7 +328,7 @@ def __init__( ) self._eval_loss_fn.apply_tp( tp_mesh, - input_layouts=(Shard(1), Replicate(), Replicate()), + input_layouts=(Shard(1), Replicate()), use_local_output=True, ) tp_config.maybe_enable_async_tp(tp_mesh) @@ -435,11 +435,7 @@ def loss_fn( # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' # (the total number of tokens used in the loss across the whole batch, not just the micro batch) # to avoid biasing the loss in the case where micro-batches might not be the same size. - ce_loss, z_loss = self._train_loss_fn(logits, labels) - - ce_loss.div_(batch_num_tokens_for_loss) - if z_loss is not None: - z_loss.div_(batch_num_tokens_for_loss) + ce_loss, z_loss = self._train_loss_fn(logits, labels, batch_num_tokens_for_loss) # Get loss to optimize for. loss = ce_loss diff --git a/src/olmo_core/train/train_module/transformer_pipeline.py b/src/olmo_core/train/train_module/transformer_pipeline.py index 8f2fa06e1..030359de1 100644 --- a/src/olmo_core/train/train_module/transformer_pipeline.py +++ b/src/olmo_core/train/train_module/transformer_pipeline.py @@ -376,7 +376,9 @@ def __init__( ) if tp_config.loss_parallel: self._train_loss_fn.apply_tp( - tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True + tp_mesh, + input_layouts=(Shard(1), Replicate(), Replicate()), + use_local_output=True, ) self._eval_loss_fn.apply_tp( tp_mesh, input_layouts=(Shard(1), Replicate()), use_local_output=True @@ -503,11 +505,7 @@ def loss_fn(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' # (the total number of tokens used in the loss across the whole batch, not just the micro batch) # to avoid biasing the loss in the case where micro-batches might not be the same size. - ce_loss, z_loss = self._train_loss_fn(logits, labels) - - ce_loss.div_(self._batch_num_tokens_for_loss) - if z_loss is not None: - z_loss.div_(self._batch_num_tokens_for_loss) + ce_loss, z_loss = self._train_loss_fn(logits, labels, self._batch_num_tokens_for_loss) # Get loss to optimize for. 
loss = ce_loss diff --git a/src/test/nn/cross_entropy_loss_test.py b/src/test/nn/cross_entropy_loss_test.py index e37d26f85..2abb5b286 100644 --- a/src/test/nn/cross_entropy_loss_test.py +++ b/src/test/nn/cross_entropy_loss_test.py @@ -19,9 +19,6 @@ def compute_loss( batch_num_tokens_for_loss: torch.Tensor, ) -> torch.Tensor: ce_loss, z_loss = loss_fn(logits, labels, batch_num_tokens_for_loss) - # ce_loss = ce_loss / batch_num_tokens_for_loss - # if z_loss is not None: - # z_loss = z_loss / batch_num_tokens_for_loss loss = ce_loss if z_loss is not None: From c55d7332c5edc8a31a2c79af32fac58a63cdd0a7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 12:01:10 -0800 Subject: [PATCH 223/230] fix --- src/olmo_core/nn/cross_entropy_loss.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/nn/cross_entropy_loss.py b/src/olmo_core/nn/cross_entropy_loss.py index 2929c736d..96c576761 100644 --- a/src/olmo_core/nn/cross_entropy_loss.py +++ b/src/olmo_core/nn/cross_entropy_loss.py @@ -159,12 +159,19 @@ def apply_tp( output_layout: Optional[Placement] = None, use_local_output: bool = False, ): + desired_input_layouts: Tuple[Placement, ...] + if input_layouts is None or len(input_layouts) == 3: + desired_input_layouts = (Shard(shard_dimension), Shard(shard_dimension), Replicate()) + elif len(input_layouts) == 2: + desired_input_layouts = (Shard(shard_dimension), Shard(shard_dimension)) + else: + raise ValueError(f"expected 2 or 3 input layouts, found {len(input_layouts)}") parallelize_module( self, device_mesh=tp_mesh, parallelize_plan=PrepareModuleInput( input_layouts=input_layouts, # type: ignore - desired_input_layouts=(Shard(shard_dimension), Shard(shard_dimension), Replicate()), # type: ignore + desired_input_layouts=desired_input_layouts, # type: ignore use_local_output=False, ), ) From f4ec2e128f0f32de931007dc86c4db279f91cdd9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 12:22:34 -0800 Subject: [PATCH 224/230] exclude default evals from long context config --- src/olmo_core/internal/experiment.py | 24 ++++++++++++++-------- src/scripts/train/OLMo2-7B-long-context.py | 1 + 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 0173ab359..c191e4fce 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -137,6 +137,7 @@ def build_common_components( *, global_batch_size: int, sequence_length: int = 4096, + include_default_evals: bool = True, ) -> CommonComponents: root_dir = get_root_dir(cluster) @@ -178,7 +179,14 @@ def build_common_components( "config_saver": ConfigSaverCallback(), "profiler": ProfilerCallback(enabled=False), "garbage_collector": GarbageCollectorCallback(), - "lm_evaluator": LMEvaluatorCallbackConfig( + "slack_notifier": SlackNotifierCallback(name=run_name, enabled=False), + } + + if torch.cuda.is_available(): + callbacks["gpu_monitor"] = GPUMemoryMonitorCallback() + + if include_default_evals: + callbacks["lm_evaluator"] = LMEvaluatorCallbackConfig( eval_dataset=NumpyDatasetConfig.from_data_mix( DataMix.v3_small_ppl_validation, name=NumpyDatasetType.padded_fsl, @@ -188,16 +196,12 @@ def build_common_components( work_dir=get_work_dir(root_dir), ), eval_interval=1000, - ), - "downstream_evaluator": DownstreamEvaluatorCallbackConfig( + ) + callbacks["downstream_evaluator"] = DownstreamEvaluatorCallbackConfig( tasks=["hellaswag"], tokenizer=tokenizer_config, eval_interval=1000, - ), - 
"slack_notifier": SlackNotifierCallback(name=run_name, enabled=False), - } - if torch.cuda.is_available(): - callbacks["gpu_monitor"] = GPUMemoryMonitorCallback() + ) return CommonComponents( run_name=run_name, @@ -223,6 +227,7 @@ def build_config( trainer_config_builder: Callable[[CommonComponents], TrainerConfig], finalize_config: Optional[Callable[[ExperimentConfig], None]] = None, sequence_length: int = 4096, + include_default_evals: bool = True, ) -> ExperimentConfig: common = build_common_components( script, @@ -232,6 +237,7 @@ def build_config( overrides, global_batch_size=global_batch_size, sequence_length=sequence_length, + include_default_evals=include_default_evals, ) model = model_config_builder(common) @@ -306,6 +312,7 @@ def main( trainer_config_builder: Callable[[CommonComponents], TrainerConfig], finalize_config: Optional[Callable[[ExperimentConfig], None]] = None, sequence_length: int = 4096, + include_default_evals: bool = True, ): usage = f""" [yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]RUN_NAME CLUSTER[/] [i][OVERRIDES...][/] @@ -346,6 +353,7 @@ def main( trainer_config_builder=trainer_config_builder, finalize_config=finalize_config, sequence_length=sequence_length, + include_default_evals=include_default_evals, ) cmd.run(config) diff --git a/src/scripts/train/OLMo2-7B-long-context.py b/src/scripts/train/OLMo2-7B-long-context.py index f2438d063..7e2ea5aaf 100644 --- a/src/scripts/train/OLMo2-7B-long-context.py +++ b/src/scripts/train/OLMo2-7B-long-context.py @@ -109,4 +109,5 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: model_config_builder=build_model_config, train_module_config_builder=build_train_module_config, trainer_config_builder=build_trainer_config, + include_default_evals=False, ) From 77d3d7121a2e0fef6d769f323d4296f9f81cb691 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 12:30:57 -0800 Subject: [PATCH 225/230] okay try without that --- src/olmo_core/nn/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/nn/utils.py b/src/olmo_core/nn/utils.py index 68ffc06b0..d8821cd8f 100644 --- a/src/olmo_core/nn/utils.py +++ b/src/olmo_core/nn/utils.py @@ -59,14 +59,19 @@ def get_tp_wrappers( else: # TODO (epwalsh): once float8 configuration supports delayed scaling, # add a check here to enforce supported float8 all-gather configurations. + # from torchao.float8.float8_tensor_parallel import ( # type: ignore + # Float8ColwiseParallel, + # Float8RowwiseParallel, + # PrepareFloat8ModuleInput, + # ) from torchao.float8.float8_tensor_parallel import ( # type: ignore Float8ColwiseParallel, Float8RowwiseParallel, - PrepareFloat8ModuleInput, ) return ( Float8RowwiseParallel, Float8ColwiseParallel, - PrepareFloat8ModuleInput, + # PrepareFloat8ModuleInput, + PrepareModuleInput, ) From c84adef24e30c620609a527267e315d4a25a61b0 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 12:36:15 -0800 Subject: [PATCH 226/230] revert --- src/olmo_core/nn/utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/nn/utils.py b/src/olmo_core/nn/utils.py index d8821cd8f..68ffc06b0 100644 --- a/src/olmo_core/nn/utils.py +++ b/src/olmo_core/nn/utils.py @@ -59,19 +59,14 @@ def get_tp_wrappers( else: # TODO (epwalsh): once float8 configuration supports delayed scaling, # add a check here to enforce supported float8 all-gather configurations. 
- # from torchao.float8.float8_tensor_parallel import ( # type: ignore - # Float8ColwiseParallel, - # Float8RowwiseParallel, - # PrepareFloat8ModuleInput, - # ) from torchao.float8.float8_tensor_parallel import ( # type: ignore Float8ColwiseParallel, Float8RowwiseParallel, + PrepareFloat8ModuleInput, ) return ( Float8RowwiseParallel, Float8ColwiseParallel, - # PrepareFloat8ModuleInput, - PrepareModuleInput, + PrepareFloat8ModuleInput, ) From 53c1f6c9685b565ff538b67a0794572b8b3a73c5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 12:46:24 -0800 Subject: [PATCH 227/230] update install instructions --- README.md | 2 +- docs/source/overview/installation.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 926a2a764..e58e8fa0f 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ pip install ai2-olmo-core There are a number of optional dependencies that must be installed to use certain functionality as well, including: - [flash-attn](https://github.com/Dao-AILab/flash-attention) for flash attention and certain other fused operations. - [torchao](https://github.com/pytorch/ao) for float8 training. -- [grouped_gemm](https://github.com/tgale96/grouped_gemm) for mixture-of-experts (MoE) models. You may need to compile from source until [PR #21](https://github.com/tgale96/grouped_gemm/pull/21) is released (post v0.1.6). +- [grouped_gemm](https://github.com/tgale96/grouped_gemm) for dropless mixture-of-experts (MoE) models. You may need to compile from source until [PR #21](https://github.com/tgale96/grouped_gemm/pull/21) is released (post v0.1.6). The published [Docker images](https://github.com/orgs/allenai/packages?repo_name=OLMo-core) contain all core and optional dependencies, and are regularly tested on our in-house H100 clusters. But there are several things to keep in mind if you intend to use these images: diff --git a/docs/source/overview/installation.rst b/docs/source/overview/installation.rst index 3b47e9e5f..ffe4b5218 100644 --- a/docs/source/overview/installation.rst +++ b/docs/source/overview/installation.rst @@ -12,4 +12,4 @@ There are a number of optional dependencies that must be installed to use certai - `flash-attn `_ for flash attention and certain other fused operations. - `torchao `_ for float8 training (see :mod:`olmo_core.float8`). -- `grouped_gemm `_ for mixture-of-experts (MoE) models (see :mod:`olmo_core.nn.moe`). +- `grouped_gemm `_ for dropless mixture-of-experts (MoE) models (see :mod:`olmo_core.nn.moe`). 
From 45fe06c95da21773f7e1bd534f8a454f79e7d284 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 13:00:15 -0800 Subject: [PATCH 228/230] try with host-networking --- .github/workflows/main.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 683b4aafd..75f0e0223 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -182,6 +182,7 @@ jobs: preemptible: true resources: gpuCount: ${{ matrix.task.gpus }} + hostNetworking: true constraints: cluster: # H100 clusters From 9e656d71d60a8e68b5c2e60104b7aea5a68e9252 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 13:04:31 -0800 Subject: [PATCH 229/230] use multiple GPUs for MoE test --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 75f0e0223..5cc5e0806 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -128,7 +128,7 @@ jobs: - name: Test MoE (GPU) image: olmo-core-tch260cu124 - gpus: 1 + gpus: 2 run: | pytest -v --color=yes --durations=3 -m gpu \ src/test/nn/moe* From 06d8140b4e344ea1f58a652505d6ecfae9f01a01 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 10 Feb 2025 13:11:38 -0800 Subject: [PATCH 230/230] update long context defaults --- src/scripts/train/OLMo2-7B-long-context.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scripts/train/OLMo2-7B-long-context.py b/src/scripts/train/OLMo2-7B-long-context.py index 7e2ea5aaf..f3b79edfc 100644 --- a/src/scripts/train/OLMo2-7B-long-context.py +++ b/src/scripts/train/OLMo2-7B-long-context.py @@ -44,6 +44,8 @@ def build_train_module_config(common: CommonComponents) -> TransformerTrainModul fused=True, ), compile_model=True, + compile_loss=True, + z_loss_multiplier=1e-5, dp_config=TransformerDataParallelConfig( name=DataParallelType.fsdp, param_dtype=DType.bfloat16, @@ -55,9 +57,7 @@ def build_train_module_config(common: CommonComponents) -> TransformerTrainModul loss_parallel=True, ), ac_config=TransformerActivationCheckpointingConfig(), - float8_config=Float8Config(enabled=True), - z_loss_multiplier=1e-5, - compile_loss=True, + float8_config=Float8Config(enabled=False), # TODO (epwalsh): broken with TP max_grad_norm=1.0, scheduler=CosWithWarmup(warmup_steps=2000), )
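For reference, a compact sketch of how the loss-parallel pieces from this series fit together: sequence-sharded logits, a `CrossEntropyLoss` with its tensor-parallel plan applied, and the batch-level divisor passed straight through. It mirrors the train module and the unit test rather than reproducing either one; the 2-GPU mesh, shapes, and hyperparameters are assumptions for illustration.

    import torch
    from torch.distributed import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor

    from olmo_core.nn.cross_entropy_loss import CrossEntropyLoss

    # Assumes a `torchrun --nproc-per-node=2` launch on a CUDA machine.
    tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

    loss_fn = CrossEntropyLoss(
        reduction="sum", fused=False, compile=False, z_loss_multiplier=1e-5
    )
    # Logits arrive sharded over the sequence dimension; labels and the divisor are
    # replicated, matching how the transformer train module applies the plan.
    loss_fn.apply_tp(
        tp_mesh,
        input_layouts=(Shard(1), Replicate(), Replicate()),
        use_local_output=True,
    )

    B, S, V = 2, 64, 128
    logits = distribute_tensor(
        torch.randn(B, S, V, device="cuda"), device_mesh=tp_mesh, placements=(Shard(1),)
    ).requires_grad_()
    labels = torch.randint(0, V, (B, S), device="cuda")
    batch_num_tokens_for_loss = (labels != -100).sum()

    ce_loss, z_loss = loss_fn(logits, labels, batch_num_tokens_for_loss)
    loss = ce_loss if z_loss is None else ce_loss + z_loss
    loss.backward()  # gradients land on the sharded logits DTensor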