Skip to content

Commit 6dbecde

Browse files
committed
move validation to modelCfg
1 parent 60778b7 commit 6dbecde

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

src/model_constructor/model_constructor.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,31 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4747
)
4848
stem_bn_end: bool = False
4949

50+
@field_validator("se")
51+
def set_se( # pylint: disable=no-self-argument
52+
cls, value: Union[bool, type[nn.Module]]
53+
) -> Union[bool, type[nn.Module]]:
54+
if value:
55+
if isinstance(value, (int, bool)):
56+
return SEModule
57+
return value
58+
59+
@field_validator("sa")
60+
def set_sa( # pylint: disable=no-self-argument
61+
cls, value: Union[bool, type[nn.Module]]
62+
) -> Union[bool, type[nn.Module]]:
63+
if value:
64+
if isinstance(value, (int, bool)):
65+
return SimpleSelfAttention # default: ks=1, sym=sym
66+
return value
67+
68+
@field_validator("se_module", "se_reduction") # pragma: no cover
69+
def deprecation_warning( # pylint: disable=no-self-argument
70+
cls, value: Union[bool, int, None]
71+
) -> Union[bool, int, None]:
72+
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")
73+
return value
74+
5075
def __repr__(self) -> str:
5176
se_repr = self.se.__name__ if self.se else "False" # type: ignore
5277
model_name = self.name or self.__class__.__name__
@@ -143,31 +168,6 @@ class ModelConstructor(ModelCfg):
143168
make_body: Callable[[ModelCfg], ModSeq] = make_body
144169
make_head: Callable[[ModelCfg], ModSeq] = make_head
145170

146-
@field_validator("se")
147-
def set_se( # pylint: disable=no-self-argument
148-
cls, value: Union[bool, type[nn.Module]]
149-
) -> Union[bool, type[nn.Module]]:
150-
if value:
151-
if isinstance(value, (int, bool)):
152-
return SEModule
153-
return value
154-
155-
@field_validator("sa")
156-
def set_sa( # pylint: disable=no-self-argument
157-
cls, value: Union[bool, type[nn.Module]]
158-
) -> Union[bool, type[nn.Module]]:
159-
if value:
160-
if isinstance(value, (int, bool)):
161-
return SimpleSelfAttention # default: ks=1, sym=sym
162-
return value
163-
164-
@field_validator("se_module", "se_reduction") # pragma: no cover
165-
def deprecation_warning( # pylint: disable=no-self-argument
166-
cls, value: Union[bool, int, None]
167-
) -> Union[bool, int, None]:
168-
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")
169-
return value
170-
171171
@property
172172
def stem(self):
173173
return self.make_stem(self) # pylint: disable=too-many-function-args

0 commit comments

Comments
 (0)