@@ -47,6 +47,31 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
47
47
)
48
48
stem_bn_end : bool = False
49
49
50
+ @field_validator ("se" )
51
+ def set_se ( # pylint: disable=no-self-argument
52
+ cls , value : Union [bool , type [nn .Module ]]
53
+ ) -> Union [bool , type [nn .Module ]]:
54
+ if value :
55
+ if isinstance (value , (int , bool )):
56
+ return SEModule
57
+ return value
58
+
59
+ @field_validator ("sa" )
60
+ def set_sa ( # pylint: disable=no-self-argument
61
+ cls , value : Union [bool , type [nn .Module ]]
62
+ ) -> Union [bool , type [nn .Module ]]:
63
+ if value :
64
+ if isinstance (value , (int , bool )):
65
+ return SimpleSelfAttention # default: ks=1, sym=sym
66
+ return value
67
+
68
+ @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
69
+ def deprecation_warning ( # pylint: disable=no-self-argument
70
+ cls , value : Union [bool , int , None ]
71
+ ) -> Union [bool , int , None ]:
72
+ print ("Deprecated. Pass se_module as se argument, se_reduction as arg to se." )
73
+ return value
74
+
50
75
def __repr__ (self ) -> str :
51
76
se_repr = self .se .__name__ if self .se else "False" # type: ignore
52
77
model_name = self .name or self .__class__ .__name__
@@ -143,31 +168,6 @@ class ModelConstructor(ModelCfg):
143
168
make_body : Callable [[ModelCfg ], ModSeq ] = make_body
144
169
make_head : Callable [[ModelCfg ], ModSeq ] = make_head
145
170
146
- @field_validator ("se" )
147
- def set_se ( # pylint: disable=no-self-argument
148
- cls , value : Union [bool , type [nn .Module ]]
149
- ) -> Union [bool , type [nn .Module ]]:
150
- if value :
151
- if isinstance (value , (int , bool )):
152
- return SEModule
153
- return value
154
-
155
- @field_validator ("sa" )
156
- def set_sa ( # pylint: disable=no-self-argument
157
- cls , value : Union [bool , type [nn .Module ]]
158
- ) -> Union [bool , type [nn .Module ]]:
159
- if value :
160
- if isinstance (value , (int , bool )):
161
- return SimpleSelfAttention # default: ks=1, sym=sym
162
- return value
163
-
164
- @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
165
- def deprecation_warning ( # pylint: disable=no-self-argument
166
- cls , value : Union [bool , int , None ]
167
- ) -> Union [bool , int , None ]:
168
- print ("Deprecated. Pass se_module as se argument, se_reduction as arg to se." )
169
- return value
170
-
171
171
@property
172
172
def stem (self ):
173
173
return self .make_stem (self ) # pylint: disable=too-many-function-args
0 commit comments