Skip to content

Commit b998f83

Browse files
authored
Merge pull request #100 from ayasyrev/obj_from_str
Obj from str
2 parents 8bfcd08 + ae4b4cf commit b998f83

File tree

2 files changed

+85
-28
lines changed

2 files changed

+85
-28
lines changed

src/model_constructor/helpers.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import importlib
12
from collections import OrderedDict
23
from functools import partial
3-
from typing import Iterable, Optional, Union
4-
from pydantic import BaseModel
4+
from typing import Any, Iterable, Optional, Union
55

6+
from pydantic import BaseModel
67
from torch import nn
78

8-
99
ListStrMod = list[tuple[str, nn.Module]]
1010
ModSeq = Union[nn.Module, nn.Sequential]
1111

@@ -25,6 +25,49 @@ def init_cnn(module: nn.Module) -> None:
2525
init_cnn(layer)
2626

2727

28+
def is_module(val: Any) -> bool:
29+
"""Check if val is a nn.Module or partial of nn.Module."""
30+
31+
to_check = val
32+
if isinstance(val, partial):
33+
to_check = val.func
34+
try:
35+
return issubclass(to_check, nn.Module)
36+
except TypeError:
37+
return False
38+
39+
40+
def instantiate_module(
41+
name: str,
42+
default_path: Optional[str] = None,
43+
) -> nn.Module:
44+
"""Instantiate model from name."""
45+
if default_path is None:
46+
path_list = name.rsplit(".", 1)
47+
if len(path_list) == 1:
48+
default_path = "torch.nn"
49+
name = path_list[0]
50+
else:
51+
if path_list[0] == "nn":
52+
default_path = "torch.nn"
53+
name = path_list[1]
54+
else:
55+
default_path = path_list[0]
56+
name = path_list[1]
57+
try:
58+
mod = importlib.import_module(default_path)
59+
except ImportError:
60+
raise ImportError(f"Module {default_path} not found")
61+
if hasattr(mod, name):
62+
module = getattr(mod, name)
63+
if is_module(module):
64+
return module
65+
else:
66+
raise ImportError(f"Module {name} is not a nn.Module")
67+
else:
68+
raise ImportError(f"Module {name} not found at {default_path}")
69+
70+
2871
class Cfg(BaseModel):
2972
"""Base class for config."""
3073

src/model_constructor/model_constructor.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from typing import Any, Callable, Optional, Union
44

55
from pydantic import field_validator
6+
from pydantic_core.core_schema import FieldValidationInfo
67
from torch import nn
78

89
from .blocks import BasicBlock, BottleneckBlock
9-
from .helpers import Cfg, ListStrMod, ModSeq, init_cnn, nn_seq
10+
from .helpers import (Cfg, ListStrMod, ModSeq, init_cnn, instantiate_module,
11+
is_module, nn_seq)
1012
from .layers import ConvBnAct, SEModule, SimpleSelfAttention
1113

1214
__all__ = [
@@ -17,53 +19,65 @@
1719
]
1820

1921

22+
DEFAULT_SE_SA = {
23+
"se": SEModule,
24+
"sa": SimpleSelfAttention,
25+
}
26+
27+
28+
nnModule = Union[type[nn.Module], Callable[[], nn.Module], str]
29+
30+
2031
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
2132
"""Model constructor Config. As default - xresnet18"""
2233

2334
name: Optional[str] = None
2435
in_chans: int = 3
2536
num_classes: int = 1000
26-
block: type[nn.Module] = BasicBlock
27-
conv_layer: type[nn.Module] = ConvBnAct
37+
block: nnModule = BasicBlock
38+
conv_layer: nnModule = ConvBnAct
2839
block_sizes: list[int] = [64, 128, 256, 512]
2940
layers: list[int] = [2, 2, 2, 2]
30-
norm: type[nn.Module] = nn.BatchNorm2d
31-
act_fn: type[nn.Module] = nn.ReLU
32-
pool: Optional[Callable[[Any], nn.Module]] = None
41+
norm: nnModule = nn.BatchNorm2d
42+
act_fn: nnModule = nn.ReLU
43+
pool: Optional[nnModule] = None
3344
expansion: int = 1
3445
groups: int = 1
3546
dw: bool = False
3647
div_groups: Optional[int] = None
37-
sa: Union[bool, type[nn.Module]] = False
38-
se: Union[bool, type[nn.Module]] = False
48+
sa: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False
49+
se: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False
3950
se_module: Optional[bool] = None
4051
se_reduction: Optional[int] = None
4152
bn_1st: bool = True
4253
zero_bn: bool = True
4354
stem_stride_on: int = 0
4455
stem_sizes: list[int] = [64]
45-
stem_pool: Optional[Callable[[], nn.Module]] = partial(
56+
stem_pool: Optional[nnModule] = partial(
4657
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
4758
)
4859
stem_bn_end: bool = False
4960

50-
@field_validator("se")
61+
@field_validator("act_fn", "block", "conv_layer", "norm", "pool", "stem_pool")
62+
def set_modules( # pylint: disable=no-self-argument
63+
cls, value: Union[type[nn.Module], str], info: FieldValidationInfo,
64+
) -> Union[type[nn.Module], Callable[[], nn.Module]]:
65+
"""Check values, if string, convert to nn.Module."""
66+
if is_module(value):
67+
return value
68+
if isinstance(value, str):
69+
return instantiate_module(value)
70+
raise ValueError(f"{info.field_name} must be str or nn.Module")
71+
72+
@field_validator("se", "sa")
5173
def set_se( # pylint: disable=no-self-argument
52-
cls, value: Union[bool, type[nn.Module]]
53-
) -> Union[bool, type[nn.Module]]:
54-
if value:
55-
if isinstance(value, (int, bool)):
56-
return SEModule
57-
return value
58-
59-
@field_validator("sa")
60-
def set_sa( # pylint: disable=no-self-argument
61-
cls, value: Union[bool, type[nn.Module]]
62-
) -> Union[bool, type[nn.Module]]:
63-
if value:
64-
if isinstance(value, (int, bool)):
65-
return SimpleSelfAttention # default: ks=1, sym=sym
66-
return value
74+
cls, value: Union[bool, type[nn.Module]], info: FieldValidationInfo,
75+
) -> Union[type[nn.Module], Callable[[], nn.Module]]:
76+
if isinstance(value, (int, bool)):
77+
return DEFAULT_SE_SA[info.field_name]
78+
if is_module(value):
79+
return value
80+
raise ValueError(f"{info.field_name} must be bool or nn.Module")
6781

6882
@field_validator("se_module", "se_reduction") # pragma: no cover
6983
def deprecation_warning( # pylint: disable=no-self-argument

0 commit comments

Comments
 (0)