Skip to content

Commit 0780b53

Browse files
authored
Merge pull request #185 from bioimage-io/fix-attachments
Test for issue with attachments
2 parents b5ac250 + 18ab64c commit 0780b53

File tree

2 files changed

+45
-31
lines changed

2 files changed

+45
-31
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from bioimageio.core import export_resource_package, load_raw_resource_description
1515
from bioimageio.core.resource_io.nodes import URI
1616
from bioimageio.core.resource_io.utils import resolve_local_source, resolve_source
17+
from bioimageio.spec.shared import fields
1718
from bioimageio.spec.shared.raw_nodes import ImportableSourceFile, ImportableModule
1819

1920
try:
@@ -60,16 +61,10 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
6061
tmp_archtecture = None
6162
weight_kwargs = {"kwargs": model_kwargs} if model_kwargs else {}
6263
if ":" in architecture:
63-
arch_file, callable_name = architecture.replace("::", ":").split(":")
64-
65-
# this goes haywire if we pass an absolute path, so need to copt to a tmp relative path
66-
if os.path.isabs(arch_file):
67-
tmp_archtecture = Path("this_model_architecture.py")
68-
copyfile(arch_file, root / tmp_archtecture)
69-
arch = ImportableSourceFile(callable_name, tmp_archtecture)
70-
else:
71-
arch = ImportableSourceFile(callable_name, Path(arch_file))
72-
64+
# note: path itself might include : for absolute paths in windows
65+
*arch_file_parts, callable_name = architecture.replace("::", ":").split(":")
66+
arch_file = _ensure_local(":".join(arch_file_parts), root)
67+
arch = ImportableSourceFile(callable_name, arch_file)
7368
arch_hash = _get_hash(root / arch.source_file)
7469
weight_kwargs["architecture_sha256"] = arch_hash
7570
else:
@@ -122,30 +117,21 @@ def _get_weights(
122117
if tensorflow_version is None:
123118
raise ValueError("tensorflow_version needs to be passed for building a keras model")
124119
weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(
125-
source=weight_source,
126-
sha256=weight_hash,
127-
tensorflow_version=tensorflow_version,
128-
**attachments,
120+
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
129121
)
130122

131123
elif weight_type == "tensorflow_saved_model_bundle":
132124
if tensorflow_version is None:
133125
raise ValueError("tensorflow_version needs to be passed for building a tensorflow model")
134126
weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(
135-
source=weight_source,
136-
sha256=weight_hash,
137-
tensorflow_version=tensorflow_version,
138-
**attachments,
127+
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
139128
)
140129

141130
elif weight_type == "tensorflow_js":
142131
if tensorflow_version is None:
143132
raise ValueError("tensorflow_version needs to be passed for building a tensorflow_js model")
144133
weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(
145-
source=weight_source,
146-
sha256=weight_hash,
147-
tensorflow_version=tensorflow_version,
148-
**attachments,
134+
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
149135
)
150136

151137
elif weight_type in weight_types:
@@ -363,7 +349,7 @@ def get_size(path):
363349
"allow_tiling": True,
364350
"model_keys": None,
365351
}
366-
return {"deepimagej": config}, attachments
352+
return {"deepimagej": config}, [Path(a) for a in attachments]
367353

368354

369355
def _write_sample_data(input_paths, output_paths, input_axes, output_axes, export_folder: Path):
@@ -518,9 +504,8 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
518504
return [_ensure_local_or_url(s, root) for s in source]
519505

520506
local_source = resolve_local_source(source, root)
521-
local_source = resolve_local_source(
522-
local_source, root, None if isinstance(local_source, URI) else root / local_source.name
523-
)
507+
if not isinstance(local_source, URI):
508+
local_source = resolve_local_source(local_source, root, root / local_source.name)
524509
return local_source.relative_to(root)
525510

526511

@@ -653,10 +638,25 @@ def build_model(
653638
Only requred for models with onnx weight format.
654639
weight_kwargs: additional keyword arguments for this weight type.
655640
"""
641+
assert architecture is None or isinstance(architecture, str)
656642
if root is None:
657643
root = "."
658644
root = Path(root)
659645

646+
if attachments is not None:
647+
assert isinstance(attachments, dict)
648+
if "files" in attachments:
649+
afiles = attachments["files"]
650+
if isinstance(afiles, str):
651+
afiles = [afiles]
652+
653+
if isinstance(afiles, list):
654+
afiles = _ensure_local_or_url(afiles, root)
655+
else:
656+
raise TypeError(attachments)
657+
658+
attachments["files"] = afiles
659+
660660
#
661661
# generate the model specific fields
662662
#
@@ -783,7 +783,7 @@ def build_model(
783783
elif "files" not in attachments:
784784
attachments["files"] = ij_attachments
785785
else:
786-
attachments["files"].extend(ij_attachments)
786+
attachments["files"] = list(set(attachments["files"]) | set(ij_attachments))
787787

788788
if links is None:
789789
links = ["deepimagej/deepimagej"]
@@ -803,7 +803,6 @@ def build_model(
803803

804804
# optional kwargs, don't pass them if none
805805
optional_kwargs = {
806-
"attachments": attachments,
807806
"config": config,
808807
"git_repo": git_repo,
809808
"packaged_by": packaged_by,
@@ -814,13 +813,15 @@ def build_model(
814813
}
815814
kwargs = {k: v for k, v in optional_kwargs.items() if v is not None}
816815

816+
if attachments is not None:
817+
kwargs["attachments"] = model_spec.raw_nodes.Attachments(**attachments)
817818
if dependencies is not None:
818819
kwargs["dependencies"] = _get_dependencies(dependencies, root)
820+
if maintainers is not None:
821+
kwargs["maintainers"] = [model_spec.raw_nodes.Maintainer(**m) for m in maintainers]
819822
if parent is not None:
820823
assert len(parent) == 2
821824
kwargs["parent"] = {"uri": parent[0], "sha256": parent[1]}
822-
if maintainers is not None:
823-
kwargs["maintainers"] = [model_spec.raw_nodes.Maintainer(**m) for m in maintainers]
824825

825826
try:
826827
model = model_spec.raw_nodes.Model(

tests/build_spec/test_build_spec.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,24 @@ def _test_build_spec(
115115
assert sample.exists()
116116

117117
assert loaded_model.maintainers[0].github_user == "jane_doe"
118+
118119
attachments = loaded_model.attachments
119120
if attachments is not missing and attachments.files is not missing:
120-
for attached_file in attachments["files"]:
121+
for attached_file in attachments.files:
121122
assert attached_file.exists()
122123

124+
# make sure there is one attachment per pre/post-processing
125+
if add_deepimagej_config:
126+
preproc, postproc = preprocessing[0], postprocessing[0]
127+
n_processing = 0
128+
if preproc is not None:
129+
n_processing += len(preproc)
130+
if postproc is not None:
131+
n_processing += len(postproc)
132+
if n_processing > 0:
133+
assert attachments.files is not missing
134+
assert n_processing == len(attachments.files)
135+
123136

124137
def test_build_spec_pytorch(any_torch_model, tmp_path):
125138
_test_build_spec(any_torch_model, tmp_path / "model.zip", "pytorch_state_dict")

0 commit comments

Comments
 (0)