|
3 | 3 | from typing import Any, Callable, Optional, Union
|
4 | 4 |
|
5 | 5 | from pydantic import field_validator
|
| 6 | +from pydantic_core.core_schema import FieldValidationInfo |
6 | 7 | from torch import nn
|
7 | 8 |
|
8 | 9 | from .blocks import BasicBlock, BottleneckBlock
|
9 |
| -from .helpers import Cfg, ListStrMod, ModSeq, init_cnn, nn_seq |
| 10 | +from .helpers import (Cfg, ListStrMod, ModSeq, init_cnn, instantiate_module, |
| 11 | + is_module, nn_seq) |
10 | 12 | from .layers import ConvBnAct, SEModule, SimpleSelfAttention
|
11 | 13 |
|
12 | 14 | __all__ = [
|
|
17 | 19 | ]
|
18 | 20 |
|
19 | 21 |
|
| 22 | +DEFAULT_SE_SA = { |
| 23 | + "se": SEModule, |
| 24 | + "sa": SimpleSelfAttention, |
| 25 | +} |
| 26 | + |
| 27 | + |
| 28 | +nnModule = Union[type[nn.Module], Callable[[], nn.Module], str] |
| 29 | + |
| 30 | + |
20 | 31 | class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
|
21 | 32 | """Model constructor Config. As default - xresnet18"""
|
22 | 33 |
|
23 | 34 | name: Optional[str] = None
|
24 | 35 | in_chans: int = 3
|
25 | 36 | num_classes: int = 1000
|
26 |
| - block: type[nn.Module] = BasicBlock |
27 |
| - conv_layer: type[nn.Module] = ConvBnAct |
| 37 | + block: nnModule = BasicBlock |
| 38 | + conv_layer: nnModule = ConvBnAct |
28 | 39 | block_sizes: list[int] = [64, 128, 256, 512]
|
29 | 40 | layers: list[int] = [2, 2, 2, 2]
|
30 |
| - norm: type[nn.Module] = nn.BatchNorm2d |
31 |
| - act_fn: type[nn.Module] = nn.ReLU |
32 |
| - pool: Optional[Callable[[Any], nn.Module]] = None |
| 41 | + norm: nnModule = nn.BatchNorm2d |
| 42 | + act_fn: nnModule = nn.ReLU |
| 43 | + pool: Optional[nnModule] = None |
33 | 44 | expansion: int = 1
|
34 | 45 | groups: int = 1
|
35 | 46 | dw: bool = False
|
36 | 47 | div_groups: Optional[int] = None
|
37 |
| - sa: Union[bool, type[nn.Module]] = False |
38 |
| - se: Union[bool, type[nn.Module]] = False |
| 48 | + sa: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False |
| 49 | + se: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False |
39 | 50 | se_module: Optional[bool] = None
|
40 | 51 | se_reduction: Optional[int] = None
|
41 | 52 | bn_1st: bool = True
|
42 | 53 | zero_bn: bool = True
|
43 | 54 | stem_stride_on: int = 0
|
44 | 55 | stem_sizes: list[int] = [64]
|
45 |
| - stem_pool: Optional[Callable[[], nn.Module]] = partial( |
| 56 | + stem_pool: Optional[nnModule] = partial( |
46 | 57 | nn.MaxPool2d, kernel_size=3, stride=2, padding=1
|
47 | 58 | )
|
48 | 59 | stem_bn_end: bool = False
|
49 | 60 |
|
50 |
| - @field_validator("se") |
| 61 | + @field_validator("act_fn", "block", "conv_layer", "norm", "pool", "stem_pool") |
| 62 | + def set_modules( # pylint: disable=no-self-argument |
| 63 | + cls, value: Union[type[nn.Module], str], info: FieldValidationInfo, |
| 64 | + ) -> Union[type[nn.Module], Callable[[], nn.Module]]: |
| 65 | + """Check values, if string, convert to nn.Module.""" |
| 66 | + if is_module(value): |
| 67 | + return value |
| 68 | + if isinstance(value, str): |
| 69 | + return instantiate_module(value) |
| 70 | + raise ValueError(f"{info.field_name} must be str or nn.Module") |
| 71 | + |
| 72 | + @field_validator("se", "sa") |
51 | 73 | def set_se( # pylint: disable=no-self-argument
|
52 |
| - cls, value: Union[bool, type[nn.Module]] |
53 |
| - ) -> Union[bool, type[nn.Module]]: |
54 |
| - if value: |
55 |
| - if isinstance(value, (int, bool)): |
56 |
| - return SEModule |
57 |
| - return value |
58 |
| - |
59 |
| - @field_validator("sa") |
60 |
| - def set_sa( # pylint: disable=no-self-argument |
61 |
| - cls, value: Union[bool, type[nn.Module]] |
62 |
| - ) -> Union[bool, type[nn.Module]]: |
63 |
| - if value: |
64 |
| - if isinstance(value, (int, bool)): |
65 |
| - return SimpleSelfAttention # default: ks=1, sym=sym |
66 |
| - return value |
| 74 | + cls, value: Union[bool, type[nn.Module]], info: FieldValidationInfo, |
| 75 | + ) -> Union[type[nn.Module], Callable[[], nn.Module]]: |
| 76 | + if isinstance(value, (int, bool)): |
| 77 | + return DEFAULT_SE_SA[info.field_name] |
| 78 | + if is_module(value): |
| 79 | + return value |
| 80 | + raise ValueError(f"{info.field_name} must be bool or nn.Module") |
67 | 81 |
|
68 | 82 | @field_validator("se_module", "se_reduction") # pragma: no cover
|
69 | 83 | def deprecation_warning( # pylint: disable=no-self-argument
|
|
0 commit comments