Skip to content

Commit 963b445

Browse files
Clean up weights in cwd in add_weights
1 parent 753f3e8 commit 963b445

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

bioimageio/core/build_spec/add_weights.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def add_weights(
6363
except Exception as e:
6464
raise e
6565
finally:
66+
# clean up tmp files
67+
os.remove(weight_out)
6668
if tmp_arch is not None:
6769
os.remove(tmp_arch)
70+
# for some reason the weights are also copied to the cwd.
71+
# not sure why this happens, but it needs to be cleaned up, unless these are the input weigths
72+
weights_cwd = Path(os.path.split(weight_uri)[1])
73+
if weights_cwd.exists() and weights_cwd.absolute() != Path(weight_uri).absolute():
74+
os.remove(weights_cwd)
6875
return model

tests/build_spec/test_add_weights.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from bioimageio.core import export_resource_package, load_raw_resource_description, load_resource_description
23
from bioimageio.core.resource_tests import test_model as _test_model
34

@@ -33,6 +34,9 @@ def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs):
3334
test_res = _test_model(out_path)
3435
assert test_res["error"] is None
3536

37+
# make sure the weights were cleaned from the cwd
38+
assert not os.path.exists(os.path.split(weight_path)[1])
39+
3640

3741
def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path):
3842
_test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript")

0 commit comments

Comments
 (0)