Skip to content

Commit a963290

Browse files
committed
make body
1 parent 6c8a6e9 commit a963290

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/model_constructor/model_constructor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,15 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
155155
)
156156

157157

158-
def make_body(cfg: ModelCfg) -> nn.Sequential:
158+
def make_body(
159+
cfg: ModelCfg,
160+
layer_constructor: Callable[[ModelCfg, int], nn.Sequential] = make_layer,
161+
) -> nn.Sequential:
159162
"""Create model body."""
163+
if hasattr(cfg, "make_layer"):
164+
layer_constructor = cfg.make_layer # type: ignore
160165
return nn_seq(
161-
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
166+
(f"l_{layer_num}", layer_constructor(cfg, layer_num)) # type: ignore
162167
for layer_num in range(len(cfg.layers))
163168
)
164169

0 commit comments

Comments
 (0)