Skip to content

Commit c828071

Browse files
committed
test helpers
1 parent e1b6910 commit c828071

File tree

3 files changed

+74
-3
lines changed

3 files changed

+74
-3
lines changed

src/model_constructor/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __repr_changed_args__(self) -> list[str]:
103103

104104
def print_cfg(self) -> None:
105105
"""Print full config"""
106-
print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})")
106+
print(self.__repr__())
107107

108108
def print_changed(self) -> None:
109109
"""Print changed fields."""

tests/test_helpers.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from functools import partial
2+
3+
from pytest import CaptureFixture
4+
from torch import nn
5+
6+
from model_constructor.helpers import Cfg, instantiate_module, is_module
7+
8+
9+
class Cfg2(Cfg):
10+
int_value: int = 10
11+
12+
13+
def test_is_module():
14+
"""test is_module"""
15+
assert not is_module("some string")
16+
assert is_module(nn.Module)
17+
assert is_module(nn.ReLU)
18+
assert not is_module(nn)
19+
assert is_module(partial(nn.ReLU, inplace=True))
20+
assert not is_module(partial(int, "10"))
21+
22+
23+
def test_instantiate_module():
24+
"""test instantiate_module"""
25+
mod = instantiate_module("ReLU")
26+
assert mod is nn.ReLU
27+
mod = instantiate_module("nn.Tanh")
28+
assert mod is nn.Tanh
29+
mod = instantiate_module("torch.nn.SELU")
30+
assert mod is nn.SELU
31+
# wrong name
32+
try:
33+
mod = instantiate_module("wrong_name")
34+
except ImportError as err:
35+
assert str(err) == "Module wrong_name not found at torch.nn"
36+
# wrong module
37+
try:
38+
mod = instantiate_module("wrong_module.some_name")
39+
except ImportError as err:
40+
assert str(err) == "Module wrong_module not found"
41+
# not nn.Module
42+
try:
43+
mod = instantiate_module("model_constructor.helpers.instantiate_module")
44+
except ImportError as err:
45+
assert str(err) == "Module instantiate_module is not a nn.Module"
46+
47+
48+
def test_cfg_repr_print(capsys: CaptureFixture[str]):
49+
"""test repr and print results"""
50+
cfg = Cfg()
51+
repr_res = cfg.__repr__()
52+
assert repr_res == "Cfg(\n )"
53+
cfg.print_changed()
54+
out = capsys.readouterr().out
55+
assert out == "Nothing changed\n"
56+
cfg.name = "cfg_name"
57+
repr_res = cfg.__repr__()
58+
assert repr_res == "Cfg(\n name='cfg_name')"
59+
cfg.print_cfg()
60+
out = capsys.readouterr().out
61+
assert out == "Cfg(\n name='cfg_name')\n"
62+
# changed fields. default - name is not in changed
63+
cfg = Cfg2(name="cfg_name")
64+
cfg.print_changed()
65+
out = capsys.readouterr().out
66+
assert out == "Nothing changed\n"
67+
assert "name" in cfg.model_fields_set
68+
cfg = Cfg2(int_value=0)
69+
cfg.print_changed()
70+
out = capsys.readouterr().out
71+
assert out == "Changed fields:\n int_value: 0\n"

tests/test_models_universal_blocks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def test_mc(model_constructor: type[ModelConstructor], act_fn: type[nn.Module]):
3737
assert pred.shape == torch.Size([bs_test, 1000])
3838

3939

40-
def test_stem_bnend():
40+
def test_stem_bn_end():
4141
"""test stem"""
42-
mc = ModelConstructor()
42+
mc = XResNet()
4343
assert mc.stem_bn_end == False
4444
mc.stem_bn_end = True
4545
stem = mc.stem

0 commit comments

Comments
 (0)