3
3
import os
4
4
from pathlib import Path
5
5
from typing import Any , Dict , List , Optional , Tuple , Union
6
+ from warnings import warn
6
7
7
8
import imageio
8
9
import numpy as np
@@ -73,6 +74,22 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
73
74
return weight_kwargs , tmp_archtecture
74
75
75
76
77
+ def _get_attachments (attachments , root ):
78
+ assert isinstance (attachments , dict )
79
+ if "files" in attachments :
80
+ afiles = attachments ["files" ]
81
+ if isinstance (afiles , str ):
82
+ afiles = [afiles ]
83
+
84
+ if isinstance (afiles , list ):
85
+ afiles = _ensure_local_or_url (afiles , root )
86
+ else :
87
+ raise TypeError (attachments )
88
+
89
+ attachments ["files" ] = afiles
90
+ return attachments
91
+
92
+
76
93
def _get_weights (
77
94
original_weight_source ,
78
95
weight_type ,
@@ -81,67 +98,94 @@ def _get_weights(
81
98
model_kwargs = None ,
82
99
tensorflow_version = None ,
83
100
opset_version = None ,
101
+ pytorch_version = None ,
84
102
dependencies = None ,
85
- ** kwargs ,
103
+ attachments = None ,
86
104
):
87
105
weight_path = resolve_source (original_weight_source , root )
88
106
if weight_type is None :
89
107
weight_type = _infer_weight_type (weight_path )
90
108
weight_hash = _get_hash (weight_path )
91
109
92
- attachments = {"attachments" : kwargs ["weight_attachments" ]} if "weight_attachments" in kwargs else {}
93
110
weight_types = model_spec .raw_nodes .WeightsFormat
94
111
weight_source = _ensure_local_or_url (original_weight_source , root )
95
112
113
+ weight_kwargs = {"source" : weight_source , "sha256" : weight_hash }
114
+ if attachments is not None :
115
+ weight_kwargs ["attachments" ] = _get_attachments (attachments , root )
116
+ if dependencies is not None :
117
+ weight_kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
118
+
96
119
tmp_archtecture = None
97
120
if weight_type == "pytorch_state_dict" :
98
121
# pytorch-state-dict -> we need an architecture definition
99
- weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
100
- weight_kwargs .update (** attachments )
101
- weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (
102
- source = weight_source , sha256 = weight_hash , ** weight_kwargs
103
- )
104
- if dependencies is not None :
105
- weight_kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
122
+ pytorch_weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
123
+ weight_kwargs .update (** pytorch_weight_kwargs )
124
+ if pytorch_version is not None :
125
+ weight_kwargs ["pytorch_version" ] = pytorch_version
126
+ elif dependencies is None :
127
+ warn (
128
+ "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
129
+ "It may not be possible to create an environmnet where your model can be used."
130
+ )
131
+ weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (** weight_kwargs )
106
132
107
133
elif weight_type == "onnx" :
108
- if opset_version is None :
109
- raise ValueError ("opset_version needs to be passed for building an onnx model" )
110
- weights = model_spec .raw_nodes .OnnxWeightsEntry (
111
- source = weight_source , sha256 = weight_hash , opset_version = opset_version , ** attachments
112
- )
134
+ if opset_version is not None :
135
+ weight_kwargs ["opset_version" ] = opset_version
136
+ elif dependencies is None :
137
+ warn (
138
+ "You are building an onnx model but have neither passed dependencies nor the opset_version."
139
+ "It may not be possible to create an environmnet where your model can be used."
140
+ )
141
+ weights = model_spec .raw_nodes .OnnxWeightsEntry (** weight_kwargs )
113
142
114
143
elif weight_type == "torchscript" :
115
- weights = model_spec .raw_nodes .TorchscriptWeightsEntry (source = weight_source , sha256 = weight_hash , ** attachments )
144
+ if pytorch_version is not None :
145
+ weight_kwargs ["pytorch_version" ] = pytorch_version
146
+ elif dependencies is None :
147
+ warn (
148
+ "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
149
+ "It may not be possible to create an environmnet where your model can be used."
150
+ )
151
+ weights = model_spec .raw_nodes .TorchscriptWeightsEntry (** weight_kwargs )
116
152
117
153
elif weight_type == "keras_hdf5" :
118
- if tensorflow_version is None :
119
- raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
120
- weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
121
- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
122
- )
154
+ if tensorflow_version is not None :
155
+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
156
+ elif dependencies is None :
157
+ warn (
158
+ "You are building a keras model but have neither passed dependencies nor the tensorflow_version."
159
+ "It may not be possible to create an environmnet where your model can be used."
160
+ )
161
+ weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (** weight_kwargs )
123
162
124
163
elif weight_type == "tensorflow_saved_model_bundle" :
125
- if tensorflow_version is None :
126
- raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
127
- weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
128
- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
129
- )
164
+ if tensorflow_version is not None :
165
+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
166
+ elif dependencies is None :
167
+ warn (
168
+ "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
169
+ "It may not be possible to create an environmnet where your model can be used."
170
+ )
171
+ weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (** weight_kwargs )
130
172
131
173
elif weight_type == "tensorflow_js" :
132
- if tensorflow_version is None :
133
- raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
134
- weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
135
- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
136
- )
174
+ if tensorflow_version is not None :
175
+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
176
+ elif dependencies is None :
177
+ warn (
178
+ "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
179
+ "It may not be possible to create an environmnet where your model can be used."
180
+ )
181
+ weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (** weight_kwargs )
137
182
138
183
elif weight_type in weight_types :
139
184
raise ValueError (f"Weight type { weight_type } is not supported yet in 'build_spec'" )
140
185
else :
141
186
raise ValueError (f"Invalid weight type { weight_type } , expect one of { weight_types } " )
142
187
143
- weights = {weight_type : weights }
144
- return weights , tmp_archtecture
188
+ return {weight_type : weights }, tmp_archtecture
145
189
146
190
147
191
def _get_data_range (data_range , dtype ):
@@ -563,7 +607,8 @@ def build_model(
563
607
add_deepimagej_config : bool = False ,
564
608
tensorflow_version : Optional [str ] = None ,
565
609
opset_version : Optional [int ] = None ,
566
- ** weight_kwargs ,
610
+ pytorch_version : Optional [str ] = None ,
611
+ weight_attachments : Optional [Dict [str , Union [str , List [str ]]]] = None ,
567
612
):
568
613
"""Create a zipped bioimage.io model.
569
614
@@ -635,30 +680,18 @@ def build_model(
635
680
dependencies: relative path to file with dependencies for this model.
636
681
root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
637
682
add_deepimagej_config: add the deepimagej config to the model.
638
- tensorflow_version: the tensorflow version used for training the model.
639
- Only requred for models with tensorflow or keras weight format.
640
- opset_version: the opset version used in this model.
641
- Only requred for models with onnx weight format.
642
- weight_kwargs: additional keyword arguments for this weight type.
683
+ tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
684
+ opset_version: the opset version for this model. Only for onnx weights.
685
+ pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights.
686
+ weight_attachments: extra weight specific attachments.
643
687
"""
644
688
assert architecture is None or isinstance (architecture , str )
645
689
if root is None :
646
690
root = "."
647
691
root = Path (root )
648
692
649
693
if attachments is not None :
650
- assert isinstance (attachments , dict )
651
- if "files" in attachments :
652
- afiles = attachments ["files" ]
653
- if isinstance (afiles , str ):
654
- afiles = [afiles ]
655
-
656
- if isinstance (afiles , list ):
657
- afiles = _ensure_local_or_url (afiles , root )
658
- else :
659
- raise TypeError (attachments )
660
-
661
- attachments ["files" ] = afiles
694
+ attachments = _get_attachments (attachments , root )
662
695
663
696
#
664
697
# generate the model specific fields
@@ -750,8 +783,9 @@ def build_model(
750
783
model_kwargs ,
751
784
tensorflow_version = tensorflow_version ,
752
785
opset_version = opset_version ,
786
+ pytorch_version = pytorch_version ,
753
787
dependencies = dependencies ,
754
- ** weight_kwargs ,
788
+ attachments = weight_attachments ,
755
789
)
756
790
757
791
# validate the sample inputs and outputs (if given)
0 commit comments