@@ -256,6 +256,7 @@ def test_GraphImgNet():
256256
257257 # fmt:on
258258
259+
259260def test_ASDAFMNet ():
260261
261262 import torch
@@ -272,3 +273,37 @@ def test_ASDAFMNet():
272273 assert ys [0 ].shape == ys [1 ].shape == ys [2 ].shape == (5 , 128 , 128 )
273274 assert ys [1 ].min () >= 0.0
274275 assert ys [2 ].min () >= 0.0
276+
277+
278+ def test_AttentionUnet ():
279+
280+ import torch
281+ from mlspm .image .models import AttentionUNet
282+
283+ torch .manual_seed (0 )
284+
285+ device = "cpu"
286+ model = AttentionUNet (
287+ conv3d_in_channels = 1 ,
288+ conv2d_in_channels = 64 ,
289+ conv3d_out_channels = [80 , 80 , 128 ],
290+ n_in = 2 ,
291+ n_out = 3 ,
292+ merge_block_channels = [8 ],
293+ conv3d_block_channels = [8 , 16 , 32 ],
294+ conv2d_block_channels = [128 ],
295+ attention_channels = [16 , 32 , 24 ],
296+ pool_z_strides = [2 , 1 , 2 ],
297+ device = device
298+ )
299+
300+ x = [torch .rand ((5 , 1 , 128 , 128 , 10 )).to (device ), torch .rand ((5 , 1 , 128 , 128 , 10 )).to (device )]
301+ ys , att = model (x )
302+
303+ assert len (ys ) == 3
304+ assert ys [0 ].shape == ys [1 ].shape == ys [2 ].shape == (5 , 128 , 128 )
305+
306+ assert len (att ) == 3
307+ assert att [0 ].shape == (5 , 32 , 32 )
308+ assert att [1 ].shape == (5 , 64 , 64 )
309+ assert att [2 ].shape == (5 , 128 , 128 )
0 commit comments