Skip to content

Commit 03abb7f

Browse files
authoredJan 28, 2022
Merge pull request #214 from bioimage-io/build-spec-dep
Build spec dep
2 parents 902262b + 1680ee8 commit 03abb7f

File tree

3 files changed

+97
-63
lines changed

3 files changed

+97
-63
lines changed
 

‎bioimageio/core/build_spec/add_weights.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from pathlib import Path
33
from shutil import copyfile
4-
from typing import Dict, Optional, Union
4+
from typing import Dict, Optional, Union, List
55

66
from bioimageio.core import export_resource_package, load_raw_resource_description
77
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
@@ -18,7 +18,8 @@ def add_weights(
1818
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
1919
tensorflow_version: Optional[str] = None,
2020
opset_version: Optional[str] = None,
21-
**weight_kwargs,
21+
pytorch_version: Optional[str] = None,
22+
attachments: Optional[Dict[str, Union[str, List[str]]]] = None,
2223
):
2324
"""Add weight entry to bioimage.io model.
2425
@@ -31,11 +32,10 @@ def add_weights(
3132
Only required for models with pytorch_state_dict weight format.
3233
model_kwargs: the keyword arguments for the model class.
3334
Only required for models with pytorch_state_dict weight format.
34-
tensorflow_version: the tensorflow version used for training the model.
35-
Only requred for models with tensorflow or keras weight format.
36-
opset_version: the opset version used in this model.
37-
Only requred for models with onnx weight format.
38-
weight_kwargs: additional keyword arguments for the weight.
35+
tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
36+
opset_version: the opset version for this model. Only for onnx weights.
37+
pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights.
38+
attachments: extra weight specific attachments.
3939
"""
4040
model = load_raw_resource_description(model)
4141

@@ -53,7 +53,8 @@ def add_weights(
5353
model_kwargs=model_kwargs,
5454
tensorflow_version=tensorflow_version,
5555
opset_version=opset_version,
56-
**weight_kwargs,
56+
pytorch_version=pytorch_version,
57+
attachments=attachments,
5758
)
5859
model.weights.update(new_weights)
5960

‎bioimageio/core/build_spec/build_model.py

+85-51
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
from pathlib import Path
55
from typing import Any, Dict, List, Optional, Tuple, Union
6+
from warnings import warn
67

78
import imageio
89
import numpy as np
@@ -73,6 +74,22 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
7374
return weight_kwargs, tmp_archtecture
7475

7576

77+
def _get_attachments(attachments, root):
78+
assert isinstance(attachments, dict)
79+
if "files" in attachments:
80+
afiles = attachments["files"]
81+
if isinstance(afiles, str):
82+
afiles = [afiles]
83+
84+
if isinstance(afiles, list):
85+
afiles = _ensure_local_or_url(afiles, root)
86+
else:
87+
raise TypeError(attachments)
88+
89+
attachments["files"] = afiles
90+
return attachments
91+
92+
7693
def _get_weights(
7794
original_weight_source,
7895
weight_type,
@@ -81,67 +98,94 @@ def _get_weights(
8198
model_kwargs=None,
8299
tensorflow_version=None,
83100
opset_version=None,
101+
pytorch_version=None,
84102
dependencies=None,
85-
**kwargs,
103+
attachments=None,
86104
):
87105
weight_path = resolve_source(original_weight_source, root)
88106
if weight_type is None:
89107
weight_type = _infer_weight_type(weight_path)
90108
weight_hash = _get_hash(weight_path)
91109

92-
attachments = {"attachments": kwargs["weight_attachments"]} if "weight_attachments" in kwargs else {}
93110
weight_types = model_spec.raw_nodes.WeightsFormat
94111
weight_source = _ensure_local_or_url(original_weight_source, root)
95112

113+
weight_kwargs = {"source": weight_source, "sha256": weight_hash}
114+
if attachments is not None:
115+
weight_kwargs["attachments"] = _get_attachments(attachments, root)
116+
if dependencies is not None:
117+
weight_kwargs["dependencies"] = _get_dependencies(dependencies, root)
118+
96119
tmp_archtecture = None
97120
if weight_type == "pytorch_state_dict":
98121
# pytorch-state-dict -> we need an architecture definition
99-
weight_kwargs, tmp_file = _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root)
100-
weight_kwargs.update(**attachments)
101-
weights = model_spec.raw_nodes.PytorchStateDictWeightsEntry(
102-
source=weight_source, sha256=weight_hash, **weight_kwargs
103-
)
104-
if dependencies is not None:
105-
weight_kwargs["dependencies"] = _get_dependencies(dependencies, root)
122+
pytorch_weight_kwargs, tmp_file = _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root)
123+
weight_kwargs.update(**pytorch_weight_kwargs)
124+
if pytorch_version is not None:
125+
weight_kwargs["pytorch_version"] = pytorch_version
126+
elif dependencies is None:
127+
warn(
128+
"You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
129+
"It may not be possible to create an environmnet where your model can be used."
130+
)
131+
weights = model_spec.raw_nodes.PytorchStateDictWeightsEntry(**weight_kwargs)
106132

107133
elif weight_type == "onnx":
108-
if opset_version is None:
109-
raise ValueError("opset_version needs to be passed for building an onnx model")
110-
weights = model_spec.raw_nodes.OnnxWeightsEntry(
111-
source=weight_source, sha256=weight_hash, opset_version=opset_version, **attachments
112-
)
134+
if opset_version is not None:
135+
weight_kwargs["opset_version"] = opset_version
136+
elif dependencies is None:
137+
warn(
138+
"You are building an onnx model but have neither passed dependencies nor the opset_version."
139+
"It may not be possible to create an environmnet where your model can be used."
140+
)
141+
weights = model_spec.raw_nodes.OnnxWeightsEntry(**weight_kwargs)
113142

114143
elif weight_type == "torchscript":
115-
weights = model_spec.raw_nodes.TorchscriptWeightsEntry(source=weight_source, sha256=weight_hash, **attachments)
144+
if pytorch_version is not None:
145+
weight_kwargs["pytorch_version"] = pytorch_version
146+
elif dependencies is None:
147+
warn(
148+
"You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
149+
"It may not be possible to create an environmnet where your model can be used."
150+
)
151+
weights = model_spec.raw_nodes.TorchscriptWeightsEntry(**weight_kwargs)
116152

117153
elif weight_type == "keras_hdf5":
118-
if tensorflow_version is None:
119-
raise ValueError("tensorflow_version needs to be passed for building a keras model")
120-
weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(
121-
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
122-
)
154+
if tensorflow_version is not None:
155+
weight_kwargs["tensorflow_version"] = tensorflow_version
156+
elif dependencies is None:
157+
warn(
158+
"You are building a keras model but have neither passed dependencies nor the tensorflow_version."
159+
"It may not be possible to create an environmnet where your model can be used."
160+
)
161+
weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(**weight_kwargs)
123162

124163
elif weight_type == "tensorflow_saved_model_bundle":
125-
if tensorflow_version is None:
126-
raise ValueError("tensorflow_version needs to be passed for building a tensorflow model")
127-
weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(
128-
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
129-
)
164+
if tensorflow_version is not None:
165+
weight_kwargs["tensorflow_version"] = tensorflow_version
166+
elif dependencies is None:
167+
warn(
168+
"You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
169+
"It may not be possible to create an environmnet where your model can be used."
170+
)
171+
weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(**weight_kwargs)
130172

131173
elif weight_type == "tensorflow_js":
132-
if tensorflow_version is None:
133-
raise ValueError("tensorflow_version needs to be passed for building a tensorflow_js model")
134-
weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(
135-
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
136-
)
174+
if tensorflow_version is not None:
175+
weight_kwargs["tensorflow_version"] = tensorflow_version
176+
elif dependencies is None:
177+
warn(
178+
"You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
179+
"It may not be possible to create an environmnet where your model can be used."
180+
)
181+
weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(**weight_kwargs)
137182

138183
elif weight_type in weight_types:
139184
raise ValueError(f"Weight type {weight_type} is not supported yet in 'build_spec'")
140185
else:
141186
raise ValueError(f"Invalid weight type {weight_type}, expect one of {weight_types}")
142187

143-
weights = {weight_type: weights}
144-
return weights, tmp_archtecture
188+
return {weight_type: weights}, tmp_archtecture
145189

146190

147191
def _get_data_range(data_range, dtype):
@@ -563,7 +607,8 @@ def build_model(
563607
add_deepimagej_config: bool = False,
564608
tensorflow_version: Optional[str] = None,
565609
opset_version: Optional[int] = None,
566-
**weight_kwargs,
610+
pytorch_version: Optional[str] = None,
611+
weight_attachments: Optional[Dict[str, Union[str, List[str]]]] = None,
567612
):
568613
"""Create a zipped bioimage.io model.
569614
@@ -635,30 +680,18 @@ def build_model(
635680
dependencies: relative path to file with dependencies for this model.
636681
root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
637682
add_deepimagej_config: add the deepimagej config to the model.
638-
tensorflow_version: the tensorflow version used for training the model.
639-
Only requred for models with tensorflow or keras weight format.
640-
opset_version: the opset version used in this model.
641-
Only requred for models with onnx weight format.
642-
weight_kwargs: additional keyword arguments for this weight type.
683+
tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
684+
opset_version: the opset version for this model. Only for onnx weights.
685+
pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights.
686+
weight_attachments: extra weight specific attachments.
643687
"""
644688
assert architecture is None or isinstance(architecture, str)
645689
if root is None:
646690
root = "."
647691
root = Path(root)
648692

649693
if attachments is not None:
650-
assert isinstance(attachments, dict)
651-
if "files" in attachments:
652-
afiles = attachments["files"]
653-
if isinstance(afiles, str):
654-
afiles = [afiles]
655-
656-
if isinstance(afiles, list):
657-
afiles = _ensure_local_or_url(afiles, root)
658-
else:
659-
raise TypeError(attachments)
660-
661-
attachments["files"] = afiles
694+
attachments = _get_attachments(attachments, root)
662695

663696
#
664697
# generate the model specific fields
@@ -750,8 +783,9 @@ def build_model(
750783
model_kwargs,
751784
tensorflow_version=tensorflow_version,
752785
opset_version=opset_version,
786+
pytorch_version=pytorch_version,
753787
dependencies=dependencies,
754-
**weight_kwargs,
788+
attachments=weight_attachments,
755789
)
756790

757791
# validate the sample inputs and outputs (if given)

‎tests/build_spec/test_build_spec.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,21 @@ def _test_build_spec(
2727

2828
cite = {entry.text: entry.doi if entry.url is missing else entry.url for entry in model_spec.cite}
2929

30-
dep_file = None
30+
weight_spec = model_spec.weights[weight_type]
31+
dep_file = None if weight_spec.dependencies is missing else resolve_source(weight_spec.dependencies.file, root)
3132
if weight_type == "pytorch_state_dict":
32-
weight_spec = model_spec.weights["pytorch_state_dict"]
3333
model_kwargs = None if weight_spec.kwargs is missing else weight_spec.kwargs
3434
architecture = str(weight_spec.architecture)
3535
if use_absoloute_arch_path:
3636
arch_path, cls_name = architecture.split(":")
3737
arch_path = os.path.abspath(os.path.join(root, arch_path))
3838
assert os.path.exists(arch_path)
3939
architecture = f"{arch_path}:{cls_name}"
40-
dep_file = None if weight_spec.dependencies is missing else resolve_source(weight_spec.dependencies.file, root)
4140
weight_type_ = None # the weight type can be auto-detected
4241
elif weight_type == "torchscript":
4342
architecture = None
4443
model_kwargs = None
45-
weight_type_ = "torchscript" # the weight type CANNOT be auto-detcted
44+
weight_type_ = "torchscript" # the weight type CANNOT be auto-detected
4645
else:
4746
architecture = None
4847
model_kwargs = None

0 commit comments

Comments
 (0)
Please sign in to comment.