Skip to content

Commit 4fc47c3

Browse files
committed
Added test for AttentionUnet
1 parent ae450f1 commit 4fc47c3

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

mlspm/image/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
upscale2d_block_channels2: List[int] = [16, 16, 16],
7474
upscale2d_block_depth2: int = 2,
7575
split_conv_block_channels: List[int] = [16],
76-
split_conv_block_depth: List[int] = [3],
76+
split_conv_block_depth: int = 3,
7777
res_connections: bool = True,
7878
out_convs_channels: int | List[int] = 1,
7979
out_relus: bool | List[bool] = True,

tests/test_models.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def test_GraphImgNet():
256256

257257
# fmt:on
258258

259+
259260
def test_ASDAFMNet():
260261

261262
import torch
@@ -272,3 +273,37 @@ def test_ASDAFMNet():
272273
assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128)
273274
assert ys[1].min() >= 0.0
274275
assert ys[2].min() >= 0.0
276+
277+
278+
def test_AttentionUnet():
279+
280+
import torch
281+
from mlspm.image.models import AttentionUNet
282+
283+
torch.manual_seed(0)
284+
285+
device = "cpu"
286+
model = AttentionUNet(
287+
conv3d_in_channels=1,
288+
conv2d_in_channels=64,
289+
conv3d_out_channels=[80, 80, 128],
290+
n_in=2,
291+
n_out=3,
292+
merge_block_channels=[8],
293+
conv3d_block_channels=[8, 16, 32],
294+
conv2d_block_channels=[128],
295+
attention_channels= [16, 32, 24],
296+
pool_z_strides=[2, 1, 2],
297+
device=device
298+
)
299+
300+
x = [torch.rand((5, 1, 128, 128, 10)).to(device), torch.rand((5, 1, 128, 128, 10)).to(device)]
301+
ys, att = model(x)
302+
303+
assert len(ys) == 3
304+
assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128)
305+
306+
assert len(att) == 3
307+
assert att[0].shape == (5, 32, 32)
308+
assert att[1].shape == (5, 64, 64)
309+
assert att[2].shape == (5, 128, 128)

0 commit comments

Comments
 (0)