Skip to content

Commit c631317

Browse files
author
Theodoros Katzalis
committed
Fix prediction with blocking
There was an issue during the merging process of the blocks. The `local_slice` needs the halo information and not the padding. Since padding can be applied only if the expansion doesn't include the actual image pixels.
1 parent 66665c5 commit c631317

File tree

2 files changed

+51
-38
lines changed

2 files changed

+51
-38
lines changed

bioimageio/core/block_meta.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def local_slice(self) -> PerAxis[SliceInfo]:
161161
return Frozen(
162162
{
163163
a: SliceInfo(
164-
self.padding[a].left,
165-
self.padding[a].left + self.inner_shape[a],
164+
self.halo[a].left,
165+
self.halo[a].left + self.inner_shape[a],
166166
)
167167
for a in self.inner_slice
168168
}

example/model_usage.ipynb

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
"import os\n",
3535
"\n",
3636
"if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
37-
" %pip install bioimageio.core==0.6.7 torch==2.3.1 onnxruntime==1.18.0"
37+
" %pip install bioimageio.core==0.6.7 torch==2.3.1 onnxruntime==1.18.0"
3838
]
3939
},
4040
{
@@ -55,6 +55,7 @@
5555
"from bioimageio.spec.pretty_validation_errors import (\n",
5656
" enable_pretty_validation_errors_in_ipynb,\n",
5757
")\n",
58+
"\n",
5859
"enable_pretty_validation_errors_in_ipynb()"
5960
]
6061
},
@@ -78,34 +79,33 @@
7879
"import matplotlib.pyplot as plt\n",
7980
"import numpy as np\n",
8081
"\n",
82+
"\n",
8183
"# Function to display input and prediction output images\n",
8284
"def show_images(sample_tensor, prediction_tensor):\n",
83-
" input_array = sample_tensor.members['input0'].data\n",
84-
" \n",
85+
" input_array = sample_tensor.members[\"input0\"].data\n",
86+
"\n",
8587
" # Check for the number of channels to enable display\n",
8688
" input_array = np.squeeze(input_array)\n",
87-
" if len(input_array.shape)>2:\n",
89+
" if len(input_array.shape) > 2:\n",
8890
" input_array = input_array[0]\n",
8991
"\n",
90-
" output_array = prediction_tensor.members['output0'].data\n",
91-
" \n",
92+
" output_array = prediction_tensor.members[\"output0\"].data\n",
93+
"\n",
9294
" # Check for the number of channels to enable display\n",
9395
" output_array = np.squeeze(output_array)\n",
94-
" if len(output_array.shape)>2:\n",
96+
" if len(output_array.shape) > 2:\n",
9597
" output_array = output_array[0]\n",
9698
"\n",
9799
" plt.figure()\n",
98-
" ax1 = plt.subplot(1,2,1)\n",
100+
" ax1 = plt.subplot(1, 2, 1)\n",
99101
" ax1.set_title(\"Input\")\n",
100-
" ax1.axis('off')\n",
102+
" ax1.axis(\"off\")\n",
101103
" plt.imshow(input_array)\n",
102-
" ax2 = plt.subplot(1,2,2)\n",
104+
" ax2 = plt.subplot(1, 2, 2)\n",
103105
" ax2.set_title(\"Prediction\")\n",
104-
" ax2.axis('off')\n",
106+
" ax2.axis(\"off\")\n",
105107
" plt.imshow(output_array)\n",
106-
" plt.show()\n",
107-
" \n",
108-
" "
108+
" plt.show()"
109109
]
110110
},
111111
{
@@ -153,8 +153,8 @@
153153
"metadata": {},
154154
"outputs": [],
155155
"source": [
156-
"BMZ_MODEL_ID = \"\"#\"affable-shark\"\n",
157-
"BMZ_MODEL_DOI = \"\" #\"10.5281/zenodo.6287342\"\n",
156+
"BMZ_MODEL_ID = \"\" # \"affable-shark\"\n",
157+
"BMZ_MODEL_DOI = \"\" # \"10.5281/zenodo.6287342\"\n",
158158
"BMZ_MODEL_URL = \"https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/affable-shark/draft/files/rdf.yaml\""
159159
]
160160
},
@@ -178,26 +178,33 @@
178178
"# Load the model description\n",
179179
"# ------------------------------------------------------------------------------\n",
180180
"if BMZ_MODEL_ID != \"\":\n",
181-
" model = load_description(BMZ_MODEL_ID) \n",
182-
" print(f\"\\nThe model '{model.name}' with ID '{BMZ_MODEL_ID}' has been correctly loaded.\")\n",
181+
" model = load_description(BMZ_MODEL_ID)\n",
182+
" print(\n",
183+
" f\"\\nThe model '{model.name}' with ID '{BMZ_MODEL_ID}' has been correctly loaded.\"\n",
184+
" )\n",
183185
"elif BMZ_MODEL_DOI != \"\":\n",
184-
" model = load_description(BMZ_MODEL_DOI) \n",
185-
" print(f\"\\nThe model '{model.name}' with DOI '{BMZ_MODEL_DOI}' has been correctly loaded.\")\n",
186+
" model = load_description(BMZ_MODEL_DOI)\n",
187+
" print(\n",
188+
" f\"\\nThe model '{model.name}' with DOI '{BMZ_MODEL_DOI}' has been correctly loaded.\"\n",
189+
" )\n",
186190
"elif BMZ_MODEL_URL != \"\":\n",
187-
" model = load_description(BMZ_MODEL_URL) \n",
188-
" print(f\"\\nThe model '{model.name}' with URL '{BMZ_MODEL_URL}' has been correctly loaded.\")\n",
191+
" model = load_description(BMZ_MODEL_URL)\n",
192+
" print(\n",
193+
" f\"\\nThe model '{model.name}' with URL '{BMZ_MODEL_URL}' has been correctly loaded.\"\n",
194+
" )\n",
189195
"else:\n",
190-
" print('\\nPlease specify a model ID, DOI or URL')\n",
196+
" print(\"\\nPlease specify a model ID, DOI or URL\")\n",
191197
"\n",
192198
"if \"draft\" in BMZ_MODEL_ID or \"draft\" in BMZ_MODEL_DOI or \"draft\" in BMZ_MODEL_URL:\n",
193-
" print(f\"\\nThis is the DRAFT version of '{model.name}'. \\nDraft versions have not been reviewed by the Bioimage Model Zoo Team and may contain harmful code. Run with caution.\")\n",
199+
" print(\n",
200+
" f\"\\nThis is the DRAFT version of '{model.name}'. \\nDraft versions have not been reviewed by the Bioimage Model Zoo Team and may contain harmful code. Run with caution.\"\n",
201+
" )\n",
194202
"\n",
195203
"# To be added later:\n",
196204
"# elif model.version != model.lastest_version:\n",
197205
"# print('\\nThe loaded version of the model is: ' + model.version, 'but the latest version of the model is: ' + model.lastest_version)\n",
198206
"\n",
199-
"# TODO: on the model loading success responses add version loaded\n",
200-
"\n"
207+
"# TODO: on the model loading success responses add version loaded"
201208
]
202209
},
203210
{
@@ -216,7 +223,7 @@
216223
"outputs": [],
217224
"source": [
218225
"print(f\"The model '{model.name}' has the following properties and metadata\\n\")\n",
219-
"print(f\" Description:\") \n",
226+
"print(f\" Description:\")\n",
220227
"pprint(model.description)\n",
221228
"\n",
222229
"print(\"\\n The authors of the model are: \")\n",
@@ -246,7 +253,7 @@
246253
" plt.imshow(cover_data)\n",
247254
" plt.xticks([])\n",
248255
" plt.yticks([])\n",
249-
" plt.show()\n"
256+
" plt.show()"
250257
]
251258
},
252259
{
@@ -269,7 +276,9 @@
269276
"metadata": {},
270277
"outputs": [],
271278
"source": [
272-
"print(f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\")\n",
279+
"print(\n",
280+
" f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\"\n",
281+
")\n",
273282
"for ipt in model.inputs:\n",
274283
" print(f\"\\ninput '{ipt.id}' with axes:\")\n",
275284
" pprint(ipt.axes)\n",
@@ -280,9 +289,13 @@
280289
" for p in ipt.preprocessing:\n",
281290
" print(p)\n",
282291
"\n",
283-
"print(\"\\n-------------------------------------------------------------------------------\")\n",
292+
"print(\n",
293+
" \"\\n-------------------------------------------------------------------------------\"\n",
294+
")\n",
284295
"# # and what the model outputs are\n",
285-
"print(f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\")\n",
296+
"print(\n",
297+
" f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\"\n",
298+
")\n",
286299
"for out in model.outputs:\n",
287300
" print(f\"\\noutput '{out.id}' with axes:\")\n",
288301
" pprint(out.axes)\n",
@@ -572,9 +585,9 @@
572585
"\n",
573586
" # Check for the number of channels to enable display\n",
574587
" input_array = np.squeeze(input_array)\n",
575-
" if len(input_array.shape)>2:\n",
588+
" if len(input_array.shape) > 2:\n",
576589
" input_array = input_array[0]\n",
577-
" \n",
590+
"\n",
578591
" np_input_list.append(input_array)\n",
579592
"\n",
580593
"\n",
@@ -584,9 +597,9 @@
584597
"\n",
585598
" # Check for the number of channels to enable display\n",
586599
" output_array = np.squeeze(output_array)\n",
587-
" if len(output_array.shape)>2:\n",
600+
" if len(output_array.shape) > 2:\n",
588601
" output_array = output_array[0]\n",
589-
" \n",
602+
"\n",
590603
" np_output_list.append(output_array)\n",
591604
"\n",
592605
"plt.imshow(np_input_list[0])"
@@ -609,7 +622,7 @@
609622
"name": "python",
610623
"nbconvert_exporter": "python",
611624
"pygments_lexer": "ipython3",
612-
"version": "3.11.9"
625+
"version": "3.9.19"
613626
}
614627
},
615628
"nbformat": 4,

0 commit comments

Comments
 (0)