Feature/transformer sequence sharding #67

Open
wants to merge 10 commits into main
1 change: 1 addition & 0 deletions models/CHANGELOG.md
@@ -18,6 +18,7 @@ Keep it human-readable, your future self will thank you!
- Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84)
- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97)
- Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)
- Add sequence sharding strategy for TransformerProcessor [#67](https://github.com/ecmwf/anemoi-core/pull/67)

## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design

136 changes: 136 additions & 0 deletions models/src/anemoi/models/distributed/transformer.py
@@ -8,6 +8,7 @@
# nor does it submit to any jurisdiction.


import logging
from typing import Optional

import torch
@@ -17,6 +18,8 @@

from anemoi.models.distributed.utils import get_memory_format

LOGGER = logging.getLogger(__name__)


def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
"""Apply all_to_all along the head dimension.
@@ -82,6 +85,72 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)


def _halo_comm(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bool = False) -> Tensor:
"""Exchange halo regions between neighboring ranks.

Expected format is (batch_size, halo_size + sequence_length + halo_size, channels).

Parameters
----------
input_ : Tensor
Input tensor
halo_size : int
Halo size (left, right)
mgroup : ProcessGroup
Model communication group
bwd : bool
Flag to indicate if backward pass

Returns
-------
Tensor
Tensor with halo regions from neighboring ranks
"""
end = input_.shape[-2]

left_halo_slice = slice(0, halo_size)
right_halo_slice = slice(end - halo_size, end)
left_send_slice = slice(halo_size, 2 * halo_size)
right_send_slice = slice(end - 2 * halo_size, end - halo_size)

if bwd: # reverse halo exchange direction for gradient accumulation
left_halo_slice, left_send_slice = left_send_slice, left_halo_slice
right_halo_slice, right_send_slice = right_send_slice, right_halo_slice

left_send = input_[:, left_send_slice, :]
right_send = input_[:, right_send_slice, :]

# setup neighbor ranks and tensor lists for all_to_all communication
group_rank = dist.get_rank(mgroup)
group_size = dist.get_world_size(mgroup)
left_rank = group_rank - 1 if group_rank > 0 else None
right_rank = group_rank + 1 if group_rank < group_size - 1 else None

input_list = [torch.empty(0, device=input_.device) for _ in range(group_size)]
if left_rank is not None:
input_list[left_rank] = left_send
if right_rank is not None:
input_list[right_rank] = right_send
output_list = [torch.empty_like(input_i, device=input_.device) for input_i in input_list]

dist.all_to_all(output_list, input_list, group=mgroup)

if bwd: # add gradient contributions to halo regions and zero out send regions
if left_rank is not None:
input_[:, left_send_slice, :] = 0
input_[:, left_halo_slice, :] += output_list[left_rank]
if right_rank is not None:
input_[:, right_send_slice, :] = 0
input_[:, right_halo_slice, :] += output_list[right_rank]
else: # add halo regions to input tensor
if left_rank is not None:
input_[:, left_halo_slice, :] = output_list[left_rank]
if right_rank is not None:
input_[:, right_halo_slice, :] = output_list[right_rank]

return input_
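
To make the slice bookkeeping in _halo_comm concrete, the following single-process sketch (illustrative only: the two "ranks" are plain local tensors and every name and size below is hypothetical) reproduces the forward exchange by copying each rank's inner send region into its neighbour's halo region, which is what dist.all_to_all does across the model communication group:

import torch

halo_size, seq_len, channels = 2, 8, 3

data0 = torch.arange(seq_len * channels, dtype=torch.float32).reshape(1, seq_len, channels)
data1 = data0 + 100

# pad as add_halos would: rank 0 has no left neighbour, rank 1 has no right neighbour
x0 = torch.nn.functional.pad(data0, pad=(0, 0, 0, halo_size))  # (1, seq + halo, channels)
x1 = torch.nn.functional.pad(data1, pad=(0, 0, halo_size, 0))  # (1, halo + seq, channels)

end = x0.shape[-2]
# rank 0's right send region fills rank 1's left halo ...
x1[:, :halo_size, :] = x0[:, end - 2 * halo_size : end - halo_size, :]
# ... and rank 1's left send region fills rank 0's right halo
x0[:, end - halo_size :, :] = x1[:, halo_size : 2 * halo_size, :]

# each rank now holds halo_size rows of its neighbour's boundary tokens
assert torch.equal(x0[:, -halo_size:, :], data1[:, :halo_size, :])
assert torch.equal(x1[:, :halo_size, :], data0[:, -halo_size:, :])

In the backward pass the halo and send slices swap roles and the received gradients are accumulated rather than copied, which is what the bwd branch above implements.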


def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
"""Sync tensor.

@@ -130,6 +199,49 @@ def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
return _SplitSequenceParallelSection.apply(input_, shapes, mgroup)


def add_halos(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> tuple[Tensor, int, int]:
halo_size_left = halo_size if dist.get_rank(mgroup) != 0 else 0
halo_size_right = halo_size if dist.get_rank(mgroup) != dist.get_world_size(mgroup) - 1 else 0

return (
torch.nn.functional.pad(x, pad=(0, 0, halo_size_left, halo_size_right), mode="constant", value=0),
halo_size_left,
halo_size_right,
)


def remove_halos(x: Tensor, halo_size_left: int, halo_size_right: int) -> Tensor:
return x[:, :, halo_size_left : x.shape[-2] - halo_size_right, :]


def halo_exchange(x: Tensor, halo_size: int, mgroup: Optional[ProcessGroup] = None) -> tuple[Tensor, int, int]:
    """Exchange halo regions between ranks.

Parameters
----------
x : Tensor
Input tensor
halo_size : int
Halo size (left, right)
mgroup : ProcessGroup
Model communication group

Returns
-------
Tensor, int, int
Tensor appended with halo regions from neighboring ranks, left halo size, right halo size
"""
if mgroup is None or dist.get_world_size(mgroup) == 1:
return x, 0, 0

# pad tensor with halo regions
out, halo_size_left, halo_size_right = add_halos(x, halo_size, mgroup)

out = _HaloExchangeParallelSection.apply(out, halo_size, mgroup)

return out, halo_size_left, halo_size_right
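
add_halos and remove_halos are the local bookends around that exchange: pad the sequence dimension with zero halos before attention, then crop the same amount afterwards. A minimal single-process sketch of that round trip (only the local pad/crop arithmetic is exercised; the helper name is hypothetical):

import torch
from torch import Tensor


def pad_crop_roundtrip(x: Tensor, halo_left: int, halo_right: int) -> Tensor:
    """Sketch of the pad/crop pair performed by add_halos and remove_halos along dim -2."""
    padded = torch.nn.functional.pad(x, pad=(0, 0, halo_left, halo_right), mode="constant", value=0)
    # ... attention over the (halo + sequence + halo) tokens would run here ...
    return padded[:, halo_left : padded.shape[-2] - halo_right, :]


x = torch.randn(1, 16, 8)  # (batch, sequence, channels)
assert torch.equal(pad_crop_roundtrip(x, 4, 4), x)
assert torch.equal(pad_crop_roundtrip(x, 0, 0), x)  # boundary ranks pad on one side only, or not at all

remove_halos applies the same crop, but along dim -2 of the (batch, heads, grid, vars) attention output.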


class _SplitHeadsParallelSection(torch.autograd.Function):
"""Sync the input from parallel section."""

@@ -172,3 +284,27 @@ def backward(ctx, grad_output):
None,
)
return grad_output, None, None


class _HaloExchangeParallelSection(torch.autograd.Function):
"""Exchange halo regions between ranks."""

@staticmethod
def forward(ctx, input_, halo_size_, mgroup_):
ctx.halo_size = halo_size_
ctx.mgroup = mgroup_

if mgroup_:
return _halo_comm(input_, halo_size_, mgroup_)
return input_

@staticmethod
def backward(ctx, grad_output):
if ctx.mgroup:
return (
_halo_comm(grad_output, ctx.halo_size, ctx.mgroup, bwd=True),
None,
None,
)

return grad_output, None, None
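
With no model communication group both passes of _HaloExchangeParallelSection reduce to the identity, so the autograd wiring can be sanity-checked locally; a sketch only, not a substitute for a distributed test:

import torch

from anemoi.models.distributed.transformer import _HaloExchangeParallelSection

# with mgroup=None the exchange is a no-op, so gradients must pass through unchanged
x = torch.randn(1, 12, 4, dtype=torch.float64, requires_grad=True)
y = _HaloExchangeParallelSection.apply(x, 2, None)  # (input_, halo_size_, mgroup_)
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
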
74 changes: 67 additions & 7 deletions models/src/anemoi/models/layers/attention.py
@@ -25,6 +25,9 @@
else:
_FLASH_ATTENTION_AVAILABLE = True


from anemoi.models.distributed.transformer import halo_exchange
from anemoi.models.distributed.transformer import remove_halos
from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

@@ -42,6 +45,7 @@ def __init__(
is_causal: bool = False,
window_size: Optional[int] = None,
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
):
super().__init__()

@@ -55,24 +59,58 @@ def __init__(
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal
self.shard_strategy = shard_strategy

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

if shard_strategy not in ["shard_heads", "shard_sequence"]:
raise ValueError(f"Invalid shard_strategy: {shard_strategy}")

if shard_strategy == "shard_sequence": # remove this after PR #47 is merged (sliding window support)
assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy"

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def forward(
def get_qkv_shard_sequence(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> tuple[Tensor, Tensor, Tensor, int, int]:
query, key, value = self.lin_qkv(x).chunk(3, -1)
assert (
shapes[-1][0] // 2 >= self.window_size[0]
), f"Sharded sequence length ({shapes[-1][0]}) must be at least twice the window size (2*{self.window_size[0]})"

# unpack grid dimension first to allow for halo exchange
x = einops.rearrange(
x,
"(batch grid) channels -> batch grid channels",
batch=batch_size,
)

if model_comm_group:
assert (
model_comm_group.size() == 1 or batch_size == 1
), "Only batch size of 1 is supported when model is sharded accross GPUs"
# communicate halos (adds halos to x)
x_plus_halos, halo_size_left, halo_size_right = halo_exchange(
x, halo_size=self.window_size[0], mgroup=model_comm_group
)

query, key, value = self.lin_qkv(x_plus_halos).chunk(3, -1)

query, key, value = (
einops.rearrange(
t,
"batch grid (heads vars) -> batch heads grid vars",
heads=self.num_heads,
)
for t in (query, key, value)
)

return query, key, value, halo_size_left, halo_size_right

def get_qkv_shard_heads(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> tuple[Tensor, Tensor, Tensor]:
query, key, value = self.lin_qkv(x).chunk(3, -1)

query, key, value = (
einops.rearrange(
Expand All @@ -87,6 +125,24 @@ def forward(
query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)

return query, key, value

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
if model_comm_group:
assert (
model_comm_group.size() == 1 or batch_size == 1
), "Only batch size of 1 is supported when model is sharded accross GPUs"

if self.shard_strategy == "shard_sequence":
query, key, value, halo_size_left, halo_size_right = self.get_qkv_shard_sequence(
x, shapes, batch_size, model_comm_group
)
if self.shard_strategy == "shard_heads":
query, key, value = self.get_qkv_shard_heads(x, shapes, batch_size, model_comm_group)

dropout_p = self.dropout_p if self.training else 0.0

if _FLASH_ATTENTION_AVAILABLE:
@@ -104,7 +160,11 @@ def forward(
dropout_p=dropout_p,
) # expects (batch heads grid variable) format

out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
if self.shard_strategy == "shard_sequence":
out = remove_halos(out, halo_size_left, halo_size_right)
if self.shard_strategy == "shard_heads":
out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)

out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")

out = self.projection(out)
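
In the sequence-sharding path the halo width equals the attention window (halo_size=self.window_size[0] above), so each rank receives the neighbouring tokens its local queries can attend to under sliding-window flash attention, and the assert in get_qkv_shard_sequence requires every shard to be at least twice that window. A hypothetical pre-flight check mirroring that constraint (helper name, arguments and values are assumptions):

def check_sequence_sharding(seq_len: int, num_shards: int, window_size: int) -> None:
    """Mirror of the shard-length assert in get_qkv_shard_sequence (sketch only)."""
    shard_len = seq_len // num_shards  # assumes an (approximately) even split across ranks
    if shard_len // 2 < window_size:
        raise ValueError(
            f"Sharded sequence length ({shard_len}) must be at least twice the window size (2*{window_size})"
        )


check_sequence_sharding(seq_len=40_320, num_shards=4, window_size=512)  # passes: 10080 // 2 >= 512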
2 changes: 2 additions & 0 deletions models/src/anemoi/models/layers/block.py
@@ -69,6 +69,7 @@ def __init__(
activation: str,
window_size: int,
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
):
super().__init__()

@@ -87,6 +88,7 @@ def __init__(
bias=False,
is_causal=False,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

self.mlp = nn.Sequential(
4 changes: 4 additions & 0 deletions models/src/anemoi/models/layers/chunk.py
@@ -75,6 +75,7 @@ def __init__(
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
) -> None:
"""Initialize TransformerProcessor.

@@ -92,6 +93,8 @@
Activation function, by default "GELU"
dropout_p: float
Dropout probability used for multi-head self attention, default 0.0
shard_strategy: str
            Strategy for sharding, either "shard_sequence" or "shard_heads", by default "shard_heads"
"""
super().__init__(num_channels=num_channels, num_layers=num_layers)

@@ -103,6 +106,7 @@
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

def forward(
Expand Down
4 changes: 4 additions & 0 deletions models/src/anemoi/models/layers/processor.py
@@ -97,6 +97,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
shard_strategy: str = "shard_sequence",
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
@@ -117,6 +118,8 @@
Activation function, by default "GELU"
dropout_p: float, optional
Dropout probability used for multi-head self attention, default 0.0
shard_strategy: str, optional
            Strategy for sharding, either "shard_sequence" or "shard_heads", by default "shard_sequence"
"""
super().__init__(
num_channels=num_channels,
Expand All @@ -138,6 +141,7 @@ def __init__(
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

self.offload_layers(cpu_offload)
1 change: 1 addition & 0 deletions training/src/anemoi/training/config/model/transformer.yaml
@@ -14,6 +14,7 @@ processor:
num_heads: 16 # GraphTransformer or Transformer only
window_size: 512
dropout_p: 0.0 # GraphTransformer
shard_strategy: shard_sequence # Options: shard_sequence, shard_heads

encoder:
_target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper