1
1
import os
2
2
from pathlib import Path
3
3
from shutil import copyfile
4
- from typing import Dict , Optional , Union
4
+ from typing import Dict , Optional , Union , List
5
5
6
6
from bioimageio .core import export_resource_package , load_raw_resource_description
7
7
from bioimageio .spec .shared .raw_nodes import ResourceDescription as RawResourceDescription
@@ -18,7 +18,8 @@ def add_weights(
18
18
model_kwargs : Optional [Dict [str , Union [int , float , str ]]] = None ,
19
19
tensorflow_version : Optional [str ] = None ,
20
20
opset_version : Optional [str ] = None ,
21
- ** weight_kwargs ,
21
+ pytorch_version : Optional [str ] = None ,
22
+ attachments : Optional [Dict [str , Union [str , List [str ]]]] = None ,
22
23
):
23
24
"""Add weight entry to bioimage.io model.
24
25
@@ -31,11 +32,10 @@ def add_weights(
31
32
Only required for models with pytorch_state_dict weight format.
32
33
model_kwargs: the keyword arguments for the model class.
33
34
Only required for models with pytorch_state_dict weight format.
34
- tensorflow_version: the tensorflow version used for training the model.
35
- Only requred for models with tensorflow or keras weight format.
36
- opset_version: the opset version used in this model.
37
- Only requred for models with onnx weight format.
38
- weight_kwargs: additional keyword arguments for the weight.
35
+ tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
36
+ opset_version: the opset version for this model. Only for onnx weights.
37
+ pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights.
38
+ attachments: extra weight specific attachments.
39
39
"""
40
40
model = load_raw_resource_description (model )
41
41
@@ -53,7 +53,8 @@ def add_weights(
53
53
model_kwargs = model_kwargs ,
54
54
tensorflow_version = tensorflow_version ,
55
55
opset_version = opset_version ,
56
- ** weight_kwargs ,
56
+ pytorch_version = pytorch_version ,
57
+ attachments = attachments ,
57
58
)
58
59
model .weights .update (new_weights )
59
60
0 commit comments