Skip to content

Commit 3364875

Browse files
Merge pull request #233 from bioimage-io/add-training-data
Add training data
2 parents 504d845 + b92e19e commit 3364875

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -636,6 +636,7 @@ def build_model(
636636
config: Optional[Dict[str, Any]] = None,
637637
dependencies: Optional[Union[Path, str]] = None,
638638
links: Optional[List[str]] = None,
639+
training_data: Optional[Dict[str, str]] = None,
639640
root: Optional[Union[Path, str]] = None,
640641
add_deepimagej_config: bool = False,
641642
tensorflow_version: Optional[str] = None,
@@ -711,6 +712,7 @@ def build_model(
711712
parent: id of the parent model from which this model is derived and sha256 of the corresponding rdf file.
712713
config: custom configuration for this model.
713714
dependencies: relative path to file with dependencies for this model.
715+
training_data: the training data for this model, either id for a bioimageio dataset or a dataset spec.
714716
root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
715717
add_deepimagej_config: add the deepimagej config to the model.
716718
tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
@@ -887,10 +889,22 @@ def build_model(
887889

888890
if maintainers is not None:
889891
kwargs["maintainers"] = [model_spec.raw_nodes.Maintainer(**m) for m in maintainers]
892+
890893
if parent is not None:
891894
assert len(parent) == 2
892895
kwargs["parent"] = parent
893896

897+
if training_data is not None:
898+
if "id" in training_data:
899+
msg = f"If training data is specified via 'id' no other keys are allowed, got {training_data}"
900+
assert len(training_data) == 1, msg
901+
kwargs["training_data"] = training_data
902+
else:
903+
if "type" not in training_data:
904+
training_data["type"] = "dataset"
905+
if "format_version" not in training_data:
906+
training_data["format_version"] = spec.dataset.format_version
907+
894908
try:
895909
model = model_spec.raw_nodes.Model(
896910
authors=authors,

example/bioimageio-core-usage.ipynb

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -411,6 +411,9 @@
411411
"# https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/build_spec/build_model.py#L252\n",
412412
"cite = [{\"text\": cite_entry.text, \"url\": cite_entry.url} for cite_entry in model_resource.cite]\n",
413413
"\n",
414+
"# the training data used for the model can also be specified by linking to a dataset available on bioimage.io\n",
415+
"training_data = {\"id\": \"ilastik/stradist_dsb_training_data\"}\n",
416+
"\n",
414417
"# the axes descriptions for the inputs / outputs\n",
415418
"input_axes = [\"bcyx\"]\n",
416419
"output_axes = [\"bcyx\"]\n",
@@ -441,7 +444,8 @@
441444
" architecture=model_source,\n",
442445
" model_kwargs=model_resource.weights[\"pytorch_state_dict\"].kwargs,\n",
443446
" preprocessing=preprocessing,\n",
444-
" postprocessing=postprocessing\n",
447+
" postprocessing=postprocessing,\n",
448+
" training_data=training_data,\n",
445449
")"
446450
]
447451
},

tests/build_spec/test_build_spec.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ def _test_build_spec(
1717
add_deepimagej_config=False,
1818
use_original_covers=False,
1919
use_absoloute_arch_path=False,
20+
training_data=None,
2021
):
2122
from bioimageio.core.build_spec import build_model
2223

@@ -115,6 +116,8 @@ def _test_build_spec(
115116
kwargs["pixel_sizes"] = [{"x": 5.0, "y": 5.0}]
116117
if use_original_covers:
117118
kwargs["covers"] = resolve_source(model_spec.covers, root)
119+
if training_data is not None:
120+
kwargs["training_data"] = training_data
118121

119122
build_model(**kwargs)
120123
assert out_path.exists()
@@ -193,6 +196,21 @@ def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
193196
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", add_deepimagej_config=True)
194197

195198

199+
def test_build_spec_training_data1(unet2d_nuclei_broad_model, tmp_path):
200+
training_data = {"id": "ilastik/stradist_dsb_training_data"}
201+
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data)
202+
203+
204+
def test_build_spec_training_data2(unet2d_nuclei_broad_model, tmp_path):
205+
training_data = {
206+
"type": "dataset",
207+
"name": "nucleus-training-data",
208+
"description": "stardist nucleus training data",
209+
"source": "https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip",
210+
}
211+
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data)
212+
213+
196214
def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
197215
_test_build_spec(
198216
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version="1.12"

0 commit comments

Comments
 (0)