|
34 | 34 | "import os\n",
|
35 | 35 | "\n",
|
36 | 36 | "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
|
37 |
| - " %pip install bioimageio.core==0.6.7 torch==2.3.1 onnxruntime==1.18.0" |
| 37 | + " %pip install bioimageio.core==0.6.7 torch==2.3.1 onnxruntime==1.18.0" |
38 | 38 | ]
|
39 | 39 | },
|
40 | 40 | {
|
|
55 | 55 | "from bioimageio.spec.pretty_validation_errors import (\n",
|
56 | 56 | " enable_pretty_validation_errors_in_ipynb,\n",
|
57 | 57 | ")\n",
|
| 58 | + "\n", |
58 | 59 | "enable_pretty_validation_errors_in_ipynb()"
|
59 | 60 | ]
|
60 | 61 | },
|
|
78 | 79 | "import matplotlib.pyplot as plt\n",
|
79 | 80 | "import numpy as np\n",
|
80 | 81 | "\n",
|
| 82 | + "\n", |
81 | 83 | "# Function to display input and prediction output images\n",
|
82 | 84 | "def show_images(sample_tensor, prediction_tensor):\n",
|
83 |
| - " input_array = sample_tensor.members['input0'].data\n", |
84 |
| - " \n", |
| 85 | + " input_array = sample_tensor.members[\"input0\"].data\n", |
| 86 | + "\n", |
85 | 87 | " # Check for the number of channels to enable display\n",
|
86 | 88 | " input_array = np.squeeze(input_array)\n",
|
87 |
| - " if len(input_array.shape)>2:\n", |
| 89 | + " if len(input_array.shape) > 2:\n", |
88 | 90 | " input_array = input_array[0]\n",
|
89 | 91 | "\n",
|
90 |
| - " output_array = prediction_tensor.members['output0'].data\n", |
91 |
| - " \n", |
| 92 | + " output_array = prediction_tensor.members[\"output0\"].data\n", |
| 93 | + "\n", |
92 | 94 | " # Check for the number of channels to enable display\n",
|
93 | 95 | " output_array = np.squeeze(output_array)\n",
|
94 |
| - " if len(output_array.shape)>2:\n", |
| 96 | + " if len(output_array.shape) > 2:\n", |
95 | 97 | " output_array = output_array[0]\n",
|
96 | 98 | "\n",
|
97 | 99 | " plt.figure()\n",
|
98 |
| - " ax1 = plt.subplot(1,2,1)\n", |
| 100 | + " ax1 = plt.subplot(1, 2, 1)\n", |
99 | 101 | " ax1.set_title(\"Input\")\n",
|
100 |
| - " ax1.axis('off')\n", |
| 102 | + " ax1.axis(\"off\")\n", |
101 | 103 | " plt.imshow(input_array)\n",
|
102 |
| - " ax2 = plt.subplot(1,2,2)\n", |
| 104 | + " ax2 = plt.subplot(1, 2, 2)\n", |
103 | 105 | " ax2.set_title(\"Prediction\")\n",
|
104 |
| - " ax2.axis('off')\n", |
| 106 | + " ax2.axis(\"off\")\n", |
105 | 107 | " plt.imshow(output_array)\n",
|
106 |
| - " plt.show()\n", |
107 |
| - " \n", |
108 |
| - " " |
| 108 | + " plt.show()" |
109 | 109 | ]
|
110 | 110 | },
|
111 | 111 | {
|
|
153 | 153 | "metadata": {},
|
154 | 154 | "outputs": [],
|
155 | 155 | "source": [
|
156 |
| - "BMZ_MODEL_ID = \"\"#\"affable-shark\"\n", |
157 |
| - "BMZ_MODEL_DOI = \"\" #\"10.5281/zenodo.6287342\"\n", |
| 156 | + "BMZ_MODEL_ID = \"\" # \"affable-shark\"\n", |
| 157 | + "BMZ_MODEL_DOI = \"\" # \"10.5281/zenodo.6287342\"\n", |
158 | 158 | "BMZ_MODEL_URL = \"https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/affable-shark/draft/files/rdf.yaml\""
|
159 | 159 | ]
|
160 | 160 | },
|
|
178 | 178 | "# Load the model description\n",
|
179 | 179 | "# ------------------------------------------------------------------------------\n",
|
180 | 180 | "if BMZ_MODEL_ID != \"\":\n",
|
181 |
| - " model = load_description(BMZ_MODEL_ID) \n", |
182 |
| - " print(f\"\\nThe model '{model.name}' with ID '{BMZ_MODEL_ID}' has been correctly loaded.\")\n", |
| 181 | + " model = load_description(BMZ_MODEL_ID)\n", |
| 182 | + " print(\n", |
| 183 | + " f\"\\nThe model '{model.name}' with ID '{BMZ_MODEL_ID}' has been correctly loaded.\"\n", |
| 184 | + " )\n", |
183 | 185 | "elif BMZ_MODEL_DOI != \"\":\n",
|
184 |
| - " model = load_description(BMZ_MODEL_DOI) \n", |
185 |
| - " print(f\"\\nThe model '{model.name}' with DOI '{BMZ_MODEL_DOI}' has been correctly loaded.\")\n", |
| 186 | + " model = load_description(BMZ_MODEL_DOI)\n", |
| 187 | + " print(\n", |
| 188 | + " f\"\\nThe model '{model.name}' with DOI '{BMZ_MODEL_DOI}' has been correctly loaded.\"\n", |
| 189 | + " )\n", |
186 | 190 | "elif BMZ_MODEL_URL != \"\":\n",
|
187 |
| - " model = load_description(BMZ_MODEL_URL) \n", |
188 |
| - " print(f\"\\nThe model '{model.name}' with URL '{BMZ_MODEL_URL}' has been correctly loaded.\")\n", |
| 191 | + " model = load_description(BMZ_MODEL_URL)\n", |
| 192 | + " print(\n", |
| 193 | + " f\"\\nThe model '{model.name}' with URL '{BMZ_MODEL_URL}' has been correctly loaded.\"\n", |
| 194 | + " )\n", |
189 | 195 | "else:\n",
|
190 |
| - " print('\\nPlease specify a model ID, DOI or URL')\n", |
| 196 | + " print(\"\\nPlease specify a model ID, DOI or URL\")\n", |
191 | 197 | "\n",
|
192 | 198 | "if \"draft\" in BMZ_MODEL_ID or \"draft\" in BMZ_MODEL_DOI or \"draft\" in BMZ_MODEL_URL:\n",
|
193 |
| - " print(f\"\\nThis is the DRAFT version of '{model.name}'. \\nDraft versions have not been reviewed by the Bioimage Model Zoo Team and may contain harmful code. Run with caution.\")\n", |
| 199 | + " print(\n", |
| 200 | + " f\"\\nThis is the DRAFT version of '{model.name}'. \\nDraft versions have not been reviewed by the Bioimage Model Zoo Team and may contain harmful code. Run with caution.\"\n", |
| 201 | + " )\n", |
194 | 202 | "\n",
|
195 | 203 | "# To be added later:\n",
|
196 | 204 | "# elif model.version != model.lastest_version:\n",
|
197 | 205 | "# print('\\nThe loaded version of the model is: ' + model.version, 'but the latest version of the model is: ' + model.lastest_version)\n",
|
198 | 206 | "\n",
|
199 |
| - "# TODO: on the model loading success responses add version loaded\n", |
200 |
| - "\n" |
| 207 | + "# TODO: on the model loading success responses add version loaded" |
201 | 208 | ]
|
202 | 209 | },
|
203 | 210 | {
|
|
216 | 223 | "outputs": [],
|
217 | 224 | "source": [
|
218 | 225 | "print(f\"The model '{model.name}' has the following properties and metadata\\n\")\n",
|
219 |
| - "print(f\" Description:\") \n", |
| 226 | + "print(f\" Description:\")\n", |
220 | 227 | "pprint(model.description)\n",
|
221 | 228 | "\n",
|
222 | 229 | "print(\"\\n The authors of the model are: \")\n",
|
|
246 | 253 | " plt.imshow(cover_data)\n",
|
247 | 254 | " plt.xticks([])\n",
|
248 | 255 | " plt.yticks([])\n",
|
249 |
| - " plt.show()\n" |
| 256 | + " plt.show()" |
250 | 257 | ]
|
251 | 258 | },
|
252 | 259 | {
|
|
269 | 276 | "metadata": {},
|
270 | 277 | "outputs": [],
|
271 | 278 | "source": [
|
272 |
| - "print(f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\")\n", |
| 279 | + "print(\n", |
| 280 | + " f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\"\n", |
| 281 | + ")\n", |
273 | 282 | "for ipt in model.inputs:\n",
|
274 | 283 | " print(f\"\\ninput '{ipt.id}' with axes:\")\n",
|
275 | 284 | " pprint(ipt.axes)\n",
|
|
280 | 289 | " for p in ipt.preprocessing:\n",
|
281 | 290 | " print(p)\n",
|
282 | 291 | "\n",
|
283 |
| - "print(\"\\n-------------------------------------------------------------------------------\")\n", |
| 292 | + "print(\n", |
| 293 | + " \"\\n-------------------------------------------------------------------------------\"\n", |
| 294 | + ")\n", |
284 | 295 | "# # and what the model outputs are\n",
|
285 |
| - "print(f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\")\n", |
| 296 | + "print(\n", |
| 297 | + " f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\"\n", |
| 298 | + ")\n", |
286 | 299 | "for out in model.outputs:\n",
|
287 | 300 | " print(f\"\\noutput '{out.id}' with axes:\")\n",
|
288 | 301 | " pprint(out.axes)\n",
|
|
572 | 585 | "\n",
|
573 | 586 | " # Check for the number of channels to enable display\n",
|
574 | 587 | " input_array = np.squeeze(input_array)\n",
|
575 |
| - " if len(input_array.shape)>2:\n", |
| 588 | + " if len(input_array.shape) > 2:\n", |
576 | 589 | " input_array = input_array[0]\n",
|
577 |
| - " \n", |
| 590 | + "\n", |
578 | 591 | " np_input_list.append(input_array)\n",
|
579 | 592 | "\n",
|
580 | 593 | "\n",
|
|
584 | 597 | "\n",
|
585 | 598 | " # Check for the number of channels to enable display\n",
|
586 | 599 | " output_array = np.squeeze(output_array)\n",
|
587 |
| - " if len(output_array.shape)>2:\n", |
| 600 | + " if len(output_array.shape) > 2:\n", |
588 | 601 | " output_array = output_array[0]\n",
|
589 |
| - " \n", |
| 602 | + "\n", |
590 | 603 | " np_output_list.append(output_array)\n",
|
591 | 604 | "\n",
|
592 | 605 | "plt.imshow(np_input_list[0])"
|
|
609 | 622 | "name": "python",
|
610 | 623 | "nbconvert_exporter": "python",
|
611 | 624 | "pygments_lexer": "ipython3",
|
612 |
| - "version": "3.11.9" |
| 625 | + "version": "3.9.19" |
613 | 626 | }
|
614 | 627 | },
|
615 | 628 | "nbformat": 4,
|
|
0 commit comments