14
14
from bioimageio .core import export_resource_package , load_raw_resource_description
15
15
from bioimageio .core .resource_io .nodes import URI
16
16
from bioimageio .core .resource_io .utils import resolve_local_source , resolve_source
17
+ from bioimageio .spec .shared import fields
17
18
from bioimageio .spec .shared .raw_nodes import ImportableSourceFile , ImportableModule
18
19
19
20
try :
@@ -60,16 +61,10 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
60
61
tmp_archtecture = None
61
62
weight_kwargs = {"kwargs" : model_kwargs } if model_kwargs else {}
62
63
if ":" in architecture :
63
- arch_file , callable_name = architecture .replace ("::" , ":" ).split (":" )
64
-
65
- # this goes haywire if we pass an absolute path, so need to copt to a tmp relative path
66
- if os .path .isabs (arch_file ):
67
- tmp_archtecture = Path ("this_model_architecture.py" )
68
- copyfile (arch_file , root / tmp_archtecture )
69
- arch = ImportableSourceFile (callable_name , tmp_archtecture )
70
- else :
71
- arch = ImportableSourceFile (callable_name , Path (arch_file ))
72
-
64
+ # note: path itself might include : for absolute paths in windows
65
+ * arch_file_parts , callable_name = architecture .replace ("::" , ":" ).split (":" )
66
+ arch_file = _ensure_local (":" .join (arch_file_parts ), root )
67
+ arch = ImportableSourceFile (callable_name , arch_file )
73
68
arch_hash = _get_hash (root / arch .source_file )
74
69
weight_kwargs ["architecture_sha256" ] = arch_hash
75
70
else :
@@ -122,30 +117,21 @@ def _get_weights(
122
117
if tensorflow_version is None :
123
118
raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
124
119
weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
125
- source = weight_source ,
126
- sha256 = weight_hash ,
127
- tensorflow_version = tensorflow_version ,
128
- ** attachments ,
120
+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
129
121
)
130
122
131
123
elif weight_type == "tensorflow_saved_model_bundle" :
132
124
if tensorflow_version is None :
133
125
raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
134
126
weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
135
- source = weight_source ,
136
- sha256 = weight_hash ,
137
- tensorflow_version = tensorflow_version ,
138
- ** attachments ,
127
+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
139
128
)
140
129
141
130
elif weight_type == "tensorflow_js" :
142
131
if tensorflow_version is None :
143
132
raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
144
133
weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
145
- source = weight_source ,
146
- sha256 = weight_hash ,
147
- tensorflow_version = tensorflow_version ,
148
- ** attachments ,
134
+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
149
135
)
150
136
151
137
elif weight_type in weight_types :
@@ -363,7 +349,7 @@ def get_size(path):
363
349
"allow_tiling" : True ,
364
350
"model_keys" : None ,
365
351
}
366
- return {"deepimagej" : config }, attachments
352
+ return {"deepimagej" : config }, [ Path ( a ) for a in attachments ]
367
353
368
354
369
355
def _write_sample_data (input_paths , output_paths , input_axes , output_axes , export_folder : Path ):
@@ -518,9 +504,8 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
518
504
return [_ensure_local_or_url (s , root ) for s in source ]
519
505
520
506
local_source = resolve_local_source (source , root )
521
- local_source = resolve_local_source (
522
- local_source , root , None if isinstance (local_source , URI ) else root / local_source .name
523
- )
507
+ if not isinstance (local_source , URI ):
508
+ local_source = resolve_local_source (local_source , root , root / local_source .name )
524
509
return local_source .relative_to (root )
525
510
526
511
@@ -653,10 +638,25 @@ def build_model(
653
638
Only requred for models with onnx weight format.
654
639
weight_kwargs: additional keyword arguments for this weight type.
655
640
"""
641
+ assert architecture is None or isinstance (architecture , str )
656
642
if root is None :
657
643
root = "."
658
644
root = Path (root )
659
645
646
+ if attachments is not None :
647
+ assert isinstance (attachments , dict )
648
+ if "files" in attachments :
649
+ afiles = attachments ["files" ]
650
+ if isinstance (afiles , str ):
651
+ afiles = [afiles ]
652
+
653
+ if isinstance (afiles , list ):
654
+ afiles = _ensure_local_or_url (afiles , root )
655
+ else :
656
+ raise TypeError (attachments )
657
+
658
+ attachments ["files" ] = afiles
659
+
660
660
#
661
661
# generate the model specific fields
662
662
#
@@ -783,7 +783,7 @@ def build_model(
783
783
elif "files" not in attachments :
784
784
attachments ["files" ] = ij_attachments
785
785
else :
786
- attachments ["files" ]. extend ( ij_attachments )
786
+ attachments ["files" ] = list ( set ( attachments [ "files" ]) | set ( ij_attachments ) )
787
787
788
788
if links is None :
789
789
links = ["deepimagej/deepimagej" ]
@@ -803,7 +803,6 @@ def build_model(
803
803
804
804
# optional kwargs, don't pass them if none
805
805
optional_kwargs = {
806
- "attachments" : attachments ,
807
806
"config" : config ,
808
807
"git_repo" : git_repo ,
809
808
"packaged_by" : packaged_by ,
@@ -814,13 +813,15 @@ def build_model(
814
813
}
815
814
kwargs = {k : v for k , v in optional_kwargs .items () if v is not None }
816
815
816
+ if attachments is not None :
817
+ kwargs ["attachments" ] = model_spec .raw_nodes .Attachments (** attachments )
817
818
if dependencies is not None :
818
819
kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
820
+ if maintainers is not None :
821
+ kwargs ["maintainers" ] = [model_spec .raw_nodes .Maintainer (** m ) for m in maintainers ]
819
822
if parent is not None :
820
823
assert len (parent ) == 2
821
824
kwargs ["parent" ] = {"uri" : parent [0 ], "sha256" : parent [1 ]}
822
- if maintainers is not None :
823
- kwargs ["maintainers" ] = [model_spec .raw_nodes .Maintainer (** m ) for m in maintainers ]
824
825
825
826
try :
826
827
model = model_spec .raw_nodes .Model (
0 commit comments