Skip to content

Commit 18ab64c

Browse files
committed
use _ensure_local for arch_file
1 parent 3134dc7 commit 18ab64c

File tree

1 file changed

+10
-25
lines changed

1 file changed

+10
-25
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,10 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
6161
tmp_archtecture = None
6262
weight_kwargs = {"kwargs": model_kwargs} if model_kwargs else {}
6363
if ":" in architecture:
64-
arch_file, callable_name = architecture.replace("::", ":").split(":")
65-
66-
# this goes haywire if we pass an absolute path, so need to copt to a tmp relative path
67-
if os.path.isabs(arch_file):
68-
tmp_archtecture = Path("this_model_architecture.py")
69-
copyfile(arch_file, root / tmp_archtecture)
70-
arch = ImportableSourceFile(callable_name, tmp_archtecture)
71-
else:
72-
arch = ImportableSourceFile(callable_name, Path(arch_file))
73-
64+
# note: path itself might include : for absolute paths in windows
65+
*arch_file_parts, callable_name = architecture.replace("::", ":").split(":")
66+
arch_file = _ensure_local(":".join(arch_file_parts), root)
67+
arch = ImportableSourceFile(callable_name, arch_file)
7468
arch_hash = _get_hash(root / arch.source_file)
7569
weight_kwargs["architecture_sha256"] = arch_hash
7670
else:
@@ -123,30 +117,21 @@ def _get_weights(
123117
if tensorflow_version is None:
124118
raise ValueError("tensorflow_version needs to be passed for building a keras model")
125119
weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(
126-
source=weight_source,
127-
sha256=weight_hash,
128-
tensorflow_version=tensorflow_version,
129-
**attachments,
120+
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
130121
)
131122

132123
elif weight_type == "tensorflow_saved_model_bundle":
133124
if tensorflow_version is None:
134125
raise ValueError("tensorflow_version needs to be passed for building a tensorflow model")
135126
weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(
136-
source=weight_source,
137-
sha256=weight_hash,
138-
tensorflow_version=tensorflow_version,
139-
**attachments,
127+
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
140128
)
141129

142130
elif weight_type == "tensorflow_js":
143131
if tensorflow_version is None:
144132
raise ValueError("tensorflow_version needs to be passed for building a tensorflow_js model")
145133
weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(
146-
source=weight_source,
147-
sha256=weight_hash,
148-
tensorflow_version=tensorflow_version,
149-
**attachments,
134+
source=weight_source, sha256=weight_hash, tensorflow_version=tensorflow_version, **attachments
150135
)
151136

152137
elif weight_type in weight_types:
@@ -519,9 +504,8 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
519504
return [_ensure_local_or_url(s, root) for s in source]
520505

521506
local_source = resolve_local_source(source, root)
522-
local_source = resolve_local_source(
523-
local_source, root, None if isinstance(local_source, URI) else root / local_source.name
524-
)
507+
if not isinstance(local_source, URI):
508+
local_source = resolve_local_source(local_source, root, root / local_source.name)
525509
return local_source.relative_to(root)
526510

527511

@@ -654,6 +638,7 @@ def build_model(
654638
Only requred for models with onnx weight format.
655639
weight_kwargs: additional keyword arguments for this weight type.
656640
"""
641+
assert architecture is None or isinstance(architecture, str)
657642
if root is None:
658643
root = "."
659644
root = Path(root)

0 commit comments

Comments
 (0)