Commit e73cff8
Merge branch 'Project-MONAI:dev' into 4980-get-wsi-at-mpp
2 parents: 1c5a26c + 4986d7f

File tree

14 files changed: +1378 additions, -55 deletions


docs/source/networks.rst

Lines changed: 25 additions & 0 deletions
@@ -109,6 +109,16 @@ Blocks
 .. autoclass:: SABlock
     :members:
 
+`CABlock Block`
+~~~~~~~~~~~~~~~
+.. autoclass:: CABlock
+    :members:
+
+`FeedForward Block`
+~~~~~~~~~~~~~~~~~~~
+.. autoclass:: FeedForward
+    :members:
+
 `Squeeze-and-Excitation`
 ~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: ChannelSELayer
@@ -173,6 +183,16 @@ Blocks
 .. autoclass:: Subpixelupsample
 .. autoclass:: SubpixelUpSample
 
+`Downsampling`
+~~~~~~~~~~~~~~
+.. autoclass:: DownSample
+    :members:
+.. autoclass:: Downsample
+.. autoclass:: SubpixelDownsample
+    :members:
+.. autoclass:: Subpixeldownsample
+.. autoclass:: SubpixelDownSample
+
 `Registration Residual Conv Block`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: RegistrationResidualConvBlock
@@ -625,6 +645,11 @@ Nets
 .. autoclass:: ViT
     :members:
 
+`Restormer`
+~~~~~~~~~~~
+.. autoclass:: restormer
+    :members:
+
 `ViTAutoEnc`
 ~~~~~~~~~~~~
 .. autoclass:: ViTAutoEnc

monai/networks/blocks/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,12 +15,13 @@
 from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish
 from .aspp import SimpleASPP
 from .backbone_fpn_utils import BackboneWithFPN
+from .cablock import CABlock, FeedForward
 from .convolutions import Convolution, ResidualUnit
 from .crf import CRF
 from .crossattention import CrossAttentionBlock
 from .denseblock import ConvDenseBlock, DenseBlock
 from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
-from .downsample import MaxAvgPool
+from .downsample import DownSample, Downsample, MaxAvgPool, SubpixelDownsample, SubpixelDownSample, Subpixeldownsample
 from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
 from .encoder import BaseEncoder
 from .fcn import FCN, GCN, MCFCN, Refine
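
With these re-exports in place, the new blocks can be imported directly from monai.networks.blocks rather than from their defining modules. A minimal sketch, assuming a build that includes this commit:

    from monai.networks.blocks import CABlock, DownSample, FeedForward

    # the package-level names resolve to the same classes as the deep module paths
    from monai.networks.blocks.cablock import CABlock as _CABlock
    assert CABlock is _CABlock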

monai/networks/blocks/cablock.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.convolutions import Convolution
from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")

__all__ = ["FeedForward", "CABlock"]


class FeedForward(nn.Module):
    """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.

    Args:
        spatial_dims: Number of spatial dimensions (2D or 3D)
        dim: Number of input channels
        ffn_expansion_factor: Factor to expand hidden features dimension
        bias: Whether to use bias in convolution layers
    """

    def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=dim,
            out_channels=hidden_features * 2,
            kernel_size=1,
            bias=bias,
            conv_only=True,
        )

        self.dwconv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=hidden_features * 2,
            out_channels=hidden_features * 2,
            kernel_size=3,
            strides=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
            conv_only=True,
        )

        self.project_out = Convolution(
            spatial_dims=spatial_dims,
            in_channels=hidden_features,
            out_channels=dim,
            kernel_size=1,
            bias=bias,
            conv_only=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        return cast(torch.Tensor, self.project_out(F.gelu(x1) * x2))


class CABlock(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
    convolutions for local mixing before attention, achieving linear complexity vs quadratic
    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>

    Args:
        spatial_dims: Number of spatial dimensions (2D or 3D)
        dim: Number of input channels
        num_heads: Number of attention heads
        bias: Whether to use bias in convolution layers
        flash_attention: Whether to use flash attention optimization. Defaults to False.

    Raises:
        ValueError: If flash attention is not available in current PyTorch version
        ValueError: If spatial_dims is greater than 3
    """

    def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        if flash_attention and not hasattr(F, "scaled_dot_product_attention"):
            raise ValueError("Flash attention not available")
        if spatial_dims > 3:
            raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}")
        self.spatial_dims = spatial_dims
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention

        self.qkv = Convolution(
            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True
        )

        self.qkv_dwconv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=dim * 3,
            out_channels=dim * 3,
            kernel_size=3,
            strides=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
            conv_only=True,
        )

        self.project_out = Convolution(
            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True
        )

        self._attention_fn = self._get_attention_fn()

    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Flash attention implementation using scaled dot-product attention."""
        scale = float(self.temperature.mean())
        out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False)
        return out

    def _normal_attention(self, q, k, v):
        """Attention matrix multiplication with depth-wise convolutions."""
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for MDTA attention.
        1. Apply depth-wise convolutions to Q, K, V
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention matrix using flash or normal attention
        4. Reshape and project out attention output"""
        spatial_dims = x.shape[2:]

        # Project and mix
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Select attention
        if self.spatial_dims == 2:
            qkv_to_multihead = "b (head c) h w -> b head c (h w)"
            multihead_to_qkv = "b head c (h w) -> b (head c) h w"
        else:  # dims == 3
            qkv_to_multihead = "b (head c) d h w -> b head c (d h w)"
            multihead_to_qkv = "b head c (d h w) -> b (head c) d h w"

        # Reconstruct and project feature map
        q = rearrange(q, qkv_to_multihead, head=self.num_heads)
        k = rearrange(k, qkv_to_multihead, head=self.num_heads)
        v = rearrange(v, qkv_to_multihead, head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        out = self._attention_fn(q, k, v)
        out = rearrange(
            out,
            multihead_to_qkv,
            head=self.num_heads,
            **dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)),
        )

        return cast(torch.Tensor, self.project_out(out))
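
To exercise the new blocks end to end, a small smoke test is sketched below. The channel count, expansion factor, and head count are illustrative values chosen here, not defaults from this commit, and einops must be installed since cablock.py pulls in rearrange via optional_import.

    import torch

    from monai.networks.blocks.cablock import CABlock, FeedForward

    # 2D case: batch of 1, 48 channels, 64x64 feature map (illustrative sizes)
    x = torch.randn(1, 48, 64, 64)

    ffn = FeedForward(spatial_dims=2, dim=48, ffn_expansion_factor=2.66, bias=False)
    attn = CABlock(spatial_dims=2, dim=48, num_heads=8, bias=False)

    # both blocks preserve the input shape: (1, 48, 64, 64) in and out
    print(ffn(x).shape, attn(x).shape)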

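The docstring's claim of linear rather than quadratic complexity comes from where the attention matrix is built: _normal_attention contracts q and k over the flattened spatial axis, so the matrix is channels-per-head by channels-per-head and its size does not grow with resolution. A toy calculation, with shapes mirroring the 2D example above, makes that concrete:

    import torch

    b, heads, c_per_head, hw = 1, 8, 6, 64 * 64
    q = torch.randn(b, heads, c_per_head, hw)
    k = torch.randn(b, heads, c_per_head, hw)

    # channel-wise ("transposed") attention: a 6x6 matrix per head,
    # instead of the 4096x4096 matrix vanilla spatial attention would build
    attn = q @ k.transpose(-2, -1)
    print(attn.shape)  # torch.Size([1, 8, 6, 6])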