# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.convolutions import Convolution
from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")

__all__ = ["FeedForward", "CABlock"]


class FeedForward(nn.Module):
    """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using a gating mechanism.
    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.

    Args:
        spatial_dims: Number of spatial dimensions (2 or 3)
        dim: Number of input channels
        ffn_expansion_factor: Factor by which the hidden feature dimension is expanded
        bias: Whether to use bias in the convolution layers
    """

    def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)

        # 1x1 convolution expanding the channels into two branches (gate and value).
        self.project_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=dim,
            out_channels=hidden_features * 2,
            kernel_size=1,
            bias=bias,
            conv_only=True,
        )

        # Depth-wise 3x3 convolution mixing local context within each branch.
        self.dwconv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=hidden_features * 2,
            out_channels=hidden_features * 2,
            kernel_size=3,
            strides=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
            conv_only=True,
        )

        # 1x1 convolution projecting the gated features back to the input dimension.
        self.project_out = Convolution(
            spatial_dims=spatial_dims,
            in_channels=hidden_features,
            out_channels=dim,
            kernel_size=1,
            bias=bias,
            conv_only=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.project_in(x)
        # Split into gate and value branches, then gate with GELU before projecting out.
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        return cast(torch.Tensor, self.project_out(F.gelu(x1) * x2))


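# A minimal usage sketch for FeedForward (illustrative only; the tensor shapes and
# hyperparameter values below are assumptions, not part of this module):
#
#     ff = FeedForward(spatial_dims=2, dim=48, ffn_expansion_factor=2.66, bias=False)
#     x = torch.randn(2, 48, 64, 64)  # (batch, channels, H, W)
#     y = ff(x)  # the gated depth-wise feed-forward preserves the input shape: (2, 48, 64, 64)

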
class CABlock(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA): differs from standard self-attention
    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
    convolutions for local mixing before attention, achieving linear complexity in the number of
    spatial positions versus the quadratic complexity of vanilla self-attention.
    Based on: SW Zamir et al., "Restormer: Efficient Transformer for High-Resolution Image
    Restoration", 2022 <https://arxiv.org/abs/2111.09881>.

    Args:
        spatial_dims: Number of spatial dimensions (2 or 3)
        dim: Number of input channels
        num_heads: Number of attention heads
        bias: Whether to use bias in the convolution layers
        flash_attention: Whether to use the flash attention optimization. Defaults to False.

    Raises:
        ValueError: If flash attention is not available in the current PyTorch version
        ValueError: If ``spatial_dims`` is greater than 3
    """

    def __init__(self, spatial_dims: int, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        if flash_attention and not hasattr(F, "scaled_dot_product_attention"):
            raise ValueError("Flash attention is not available in the current PyTorch version.")
        if spatial_dims > 3:
            raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}")
        self.spatial_dims = spatial_dims
        self.num_heads = num_heads
        # Learnable per-head temperature scaling the channel attention map.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention

        # 1x1 convolution producing the concatenated Q, K, V projections.
        self.qkv = Convolution(
            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True
        )

        # Depth-wise 3x3 convolution for local context mixing of Q, K, V.
        self.qkv_dwconv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=dim * 3,
            out_channels=dim * 3,
            kernel_size=3,
            strides=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
            conv_only=True,
        )

        self.project_out = Convolution(
            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True
        )

        self._attention_fn = self._get_attention_fn()

    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Flash attention implementation using PyTorch's scaled dot-product attention.

        The learnable per-head temperatures are collapsed to a single scalar scale, since
        ``F.scaled_dot_product_attention`` accepts only one scale value.
        """
        scale = float(self.temperature.mean())
        out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False)
        return out

    def _normal_attention(self, q, k, v):
        """Transposed (channel) attention with learnable temperature scaling.

        q, k, v have shape (batch, heads, channels_per_head, spatial), so the attention matrix is
        channels_per_head x channels_per_head rather than spatial x spatial.
        """
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for MDTA attention.
        1. Apply depth-wise convolutions to Q, K, V.
        2. Reshape Q, K, V for multi-head attention.
        3. Compute the attention matrix using flash or normal attention.
        4. Reshape and project out the attention output.
        """
        spatial_shape = x.shape[2:]

        # Project to Q, K, V and mix local context with the depth-wise convolution.
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Choose the einops patterns for 2D vs. 3D inputs.
        if self.spatial_dims == 2:
            qkv_to_multihead = "b (head c) h w -> b head c (h w)"
            multihead_to_qkv = "b head c (h w) -> b (head c) h w"
        else:  # self.spatial_dims == 3
            qkv_to_multihead = "b (head c) d h w -> b head c (d h w)"
            multihead_to_qkv = "b head c (d h w) -> b (head c) d h w"

        # Reshape to multi-head layout: (batch, heads, channels_per_head, flattened spatial).
        q = rearrange(q, qkv_to_multihead, head=self.num_heads)
        k = rearrange(k, qkv_to_multihead, head=self.num_heads)
        v = rearrange(v, qkv_to_multihead, head=self.num_heads)

        # L2-normalize Q and K along the flattened spatial dimension before channel attention.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # Attend over channels, restore the spatial layout, and project the feature map out.
        out = self._attention_fn(q, k, v)
        out = rearrange(
            out,
            multihead_to_qkv,
            head=self.num_heads,
            **dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_shape)),
        )

        return cast(torch.Tensor, self.project_out(out))
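

# A minimal usage sketch for CABlock (illustrative only; the tensor shapes and
# hyperparameter values below are assumptions, not part of this module):
#
#     block = CABlock(spatial_dims=3, dim=32, num_heads=4, bias=False)
#     x = torch.randn(1, 32, 16, 16, 16)  # (batch, channels, D, H, W)
#     y = block(x)  # transposed (channel) attention preserves the shape: (1, 32, 16, 16, 16)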