Skip to content

Commit e3c89ed

Browse files
authored
Merge pull request #103 from ayasyrev/tests
Tests
2 parents b998f83 + c828071 commit e3c89ed

File tree

7 files changed

+142
-6
lines changed

7 files changed

+142
-6
lines changed

src/model_constructor/blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
id_layers: ListStrMod = []
6767
if (
6868
stride != 1 and pool is not None
69-
): # if pool - reduce by pool else stride 2 art id_conv
69+
): # if pool - reduce by pool else stride 2 at id_conv
7070
id_layers.append(("pool", pool()))
7171
if in_channels != out_channels or (stride != 1 and pool is None):
7272
id_layers.append(

src/model_constructor/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __repr_changed_args__(self) -> list[str]:
103103

104104
def print_cfg(self) -> None:
105105
"""Print full config"""
106-
print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})")
106+
print(self.__repr__())
107107

108108
def print_changed(self) -> None:
109109
"""Print changed fields."""

src/model_constructor/model_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def set_modules( # pylint: disable=no-self-argument
6767
return value
6868
if isinstance(value, str):
6969
return instantiate_module(value)
70-
raise ValueError(f"{info.field_name} must be str or nn.Module")
70+
# raise ValueError(f"{info.field_name} must be str or nn.Module")
7171

7272
@field_validator("se", "sa")
7373
def set_se( # pylint: disable=no-self-argument
@@ -77,7 +77,7 @@ def set_se( # pylint: disable=no-self-argument
7777
return DEFAULT_SE_SA[info.field_name]
7878
if is_module(value):
7979
return value
80-
raise ValueError(f"{info.field_name} must be bool or nn.Module")
80+
# raise ValueError(f"{info.field_name} must be bool or nn.Module") # no need - check at init
8181

8282
@field_validator("se_module", "se_reduction") # pragma: no cover
8383
def deprecation_warning( # pylint: disable=no-self-argument

tests/test_helpers.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from functools import partial
2+
3+
from pytest import CaptureFixture
4+
from torch import nn
5+
6+
from model_constructor.helpers import Cfg, instantiate_module, is_module
7+
8+
9+
class Cfg2(Cfg):
10+
int_value: int = 10
11+
12+
13+
def test_is_module():
14+
"""test is_module"""
15+
assert not is_module("some string")
16+
assert is_module(nn.Module)
17+
assert is_module(nn.ReLU)
18+
assert not is_module(nn)
19+
assert is_module(partial(nn.ReLU, inplace=True))
20+
assert not is_module(partial(int, "10"))
21+
22+
23+
def test_instantiate_module():
24+
"""test instantiate_module"""
25+
mod = instantiate_module("ReLU")
26+
assert mod is nn.ReLU
27+
mod = instantiate_module("nn.Tanh")
28+
assert mod is nn.Tanh
29+
mod = instantiate_module("torch.nn.SELU")
30+
assert mod is nn.SELU
31+
# wrong name
32+
try:
33+
mod = instantiate_module("wrong_name")
34+
except ImportError as err:
35+
assert str(err) == "Module wrong_name not found at torch.nn"
36+
# wrong module
37+
try:
38+
mod = instantiate_module("wrong_module.some_name")
39+
except ImportError as err:
40+
assert str(err) == "Module wrong_module not found"
41+
# not nn.Module
42+
try:
43+
mod = instantiate_module("model_constructor.helpers.instantiate_module")
44+
except ImportError as err:
45+
assert str(err) == "Module instantiate_module is not a nn.Module"
46+
47+
48+
def test_cfg_repr_print(capsys: CaptureFixture[str]):
49+
"""test repr and print results"""
50+
cfg = Cfg()
51+
repr_res = cfg.__repr__()
52+
assert repr_res == "Cfg(\n )"
53+
cfg.print_changed()
54+
out = capsys.readouterr().out
55+
assert out == "Nothing changed\n"
56+
cfg.name = "cfg_name"
57+
repr_res = cfg.__repr__()
58+
assert repr_res == "Cfg(\n name='cfg_name')"
59+
cfg.print_cfg()
60+
out = capsys.readouterr().out
61+
assert out == "Cfg(\n name='cfg_name')\n"
62+
# changed fields. default - name is not in changed
63+
cfg = Cfg2(name="cfg_name")
64+
cfg.print_changed()
65+
out = capsys.readouterr().out
66+
assert out == "Nothing changed\n"
67+
assert "name" in cfg.model_fields_set
68+
cfg = Cfg2(int_value=0)
69+
cfg.print_changed()
70+
out = capsys.readouterr().out
71+
assert out == "Changed fields:\n int_value: 0\n"

tests/test_mc.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import torch
44

5-
from model_constructor.blocks import BottleneckBlock
5+
from model_constructor.blocks import BasicBlock, BottleneckBlock
66
from model_constructor.layers import SEModule, SEModuleConv, SimpleSelfAttention
7-
from model_constructor.model_constructor import ModelConstructor
7+
from model_constructor.model_constructor import ModelCfg, ModelConstructor
88

99
bs_test = 4
1010
in_chans = 3
@@ -94,3 +94,46 @@ def test_MC_bottleneck():
9494
assert model.body.l_0.bl_0.convs.conv_0.conv.in_channels == 64
9595
assert model.body.l_0.bl_0.convs.conv_0.conv.out_channels == 128
9696
assert model.body.l_0.bl_1.convs.conv_0.conv.in_channels == 256
97+
98+
99+
def test_ModelCfg():
100+
"""test ModelCfg"""
101+
# default - just create config with custom name
102+
cfg = ModelCfg(name="custom_name")
103+
repr_str = cfg.__repr__()
104+
assert repr_str.startswith("custom_name")
105+
# initiate from string
106+
cfg = ModelCfg(act_fn="torch.nn.Mish")
107+
assert cfg.act_fn is torch.nn.Mish
108+
# wrong name
109+
try:
110+
cfg = ModelCfg(act_fn="wrong_name")
111+
except ImportError as err:
112+
assert str(err) == "Module wrong_name not found at torch.nn"
113+
cfg = ModelCfg(act_fn="nn.Tanh")
114+
assert cfg.act_fn is torch.nn.Tanh
115+
cfg = ModelCfg(block="model_constructor.blocks.BottleneckBlock")
116+
assert cfg.block is BottleneckBlock
117+
118+
119+
def test_create_model_class_methods():
120+
"""test class methods ModelConstructor"""
121+
# create model
122+
model = ModelConstructor.create_model(act_fn="Mish", num_classes=10)
123+
assert str(model.body.l_0.bl_0.convs.conv_0.act_fn) == "Mish(inplace=True)"
124+
pred = model(xb)
125+
assert pred.shape == torch.Size([bs_test, 10])
126+
# from cfg
127+
cfg = ModelCfg(block=BottleneckBlock, num_classes=10)
128+
mc = ModelConstructor.from_cfg(cfg)
129+
model = mc()
130+
assert isinstance(model.body.l_0.bl_0, BottleneckBlock)
131+
pred = model(xb)
132+
assert pred.shape == torch.Size([bs_test, 10])
133+
134+
cfg.block = BasicBlock
135+
cfg.num_classes = 2
136+
model = ModelConstructor.create_model(cfg)
137+
assert isinstance(model.body.l_0.bl_0, BasicBlock)
138+
pred = model(xb)
139+
assert pred.shape == torch.Size([bs_test, 2])

tests/test_models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,14 @@ def test_mc(model_constructor: type[ModelConstructor], act_fn: type[nn.Module]):
4040
model = mc()
4141
pred = model(xb)
4242
assert pred.shape == torch.Size([bs_test, 1000])
43+
44+
45+
def test_xresnet_stem():
46+
"""test xresnet stem"""
47+
mc = XResNet()
48+
assert mc.stem_bn_end == False
49+
mc.stem_bn_end = True
50+
stem = mc.stem
51+
assert isinstance(stem[-1], nn.BatchNorm2d)
52+
stem_out = stem(xb)
53+
assert stem_out.shape == torch.Size([bs_test, 64, 4, 4])

tests/test_models_universal_blocks.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,14 @@ def test_mc(model_constructor: type[ModelConstructor], act_fn: type[nn.Module]):
3535
model = mc()
3636
pred = model(xb)
3737
assert pred.shape == torch.Size([bs_test, 1000])
38+
39+
40+
def test_stem_bn_end():
41+
"""test stem"""
42+
mc = XResNet()
43+
assert mc.stem_bn_end == False
44+
mc.stem_bn_end = True
45+
stem = mc.stem
46+
assert isinstance(stem[-1], nn.BatchNorm2d)
47+
stem_out = stem(xb)
48+
assert stem_out.shape == torch.Size([bs_test, 64, 4, 4])

0 commit comments

Comments
 (0)