Skip to content

Commit 1680ee8

Browse files
Adapt add_weights to updated spec
1 parent 9e774f2 commit 1680ee8

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

bioimageio/core/build_spec/add_weights.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from pathlib import Path
33
from shutil import copyfile
4-
from typing import Dict, Optional, Union
4+
from typing import Dict, Optional, Union, List
55

66
from bioimageio.core import export_resource_package, load_raw_resource_description
77
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
@@ -18,7 +18,8 @@ def add_weights(
1818
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
1919
tensorflow_version: Optional[str] = None,
2020
opset_version: Optional[str] = None,
21-
**weight_kwargs,
21+
pytorch_version: Optional[str] = None,
22+
attachments: Optional[Dict[str, Union[str, List[str]]]] = None,
2223
):
2324
"""Add weight entry to bioimage.io model.
2425
@@ -31,11 +32,10 @@ def add_weights(
3132
Only required for models with pytorch_state_dict weight format.
3233
model_kwargs: the keyword arguments for the model class.
3334
Only required for models with pytorch_state_dict weight format.
34-
tensorflow_version: the tensorflow version used for training the model.
35-
Only requred for models with tensorflow or keras weight format.
36-
opset_version: the opset version used in this model.
37-
Only requred for models with onnx weight format.
38-
weight_kwargs: additional keyword arguments for the weight.
35+
tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
36+
opset_version: the opset version for this model. Only for onnx weights.
37+
pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights.
38+
attachments: extra weight specific attachments.
3939
"""
4040
model = load_raw_resource_description(model)
4141

@@ -53,7 +53,8 @@ def add_weights(
5353
model_kwargs=model_kwargs,
5454
tensorflow_version=tensorflow_version,
5555
opset_version=opset_version,
56-
**weight_kwargs,
56+
pytorch_version=pytorch_version,
57+
attachments=attachments,
5758
)
5859
model.weights.update(new_weights)
5960

0 commit comments

Comments
 (0)