diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index c890e8df..8ce50a1a 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -8,7 +8,7 @@ on:
defaults:
run:
- shell: micromamba-shell {0}
+ shell: bash -el {0}
jobs:
black:
@@ -22,215 +22,218 @@ jobs:
jupyter: true
version: "24.3"
- test-spec-conda:
+ populate-cache:
runs-on: ubuntu-latest
- strategy:
- matrix:
- python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+ outputs:
+ cache-key: ${{steps.cache-key.outputs.cache-key}}
steps:
- - uses: actions/checkout@v4
- - name: Install Conda environment with Micromamba
- if: matrix.python-version != '3.8'
- uses: mamba-org/setup-micromamba@v1
- with:
- cache-downloads: true
- cache-environment: true
- environment-file: dev/env-wo-python.yaml
- create-args: >-
- python=${{ matrix.python-version }}
- post-cleanup: 'all'
- - name: Install py3.8 environment
- if: matrix.python-version == '3.8'
- uses: mamba-org/setup-micromamba@v1
- with:
- cache-downloads: true
- cache-environment: true
- environment-file: dev/env-py38.yaml
- post-cleanup: 'all'
- - name: additional setup
- run: pip install --no-deps -e .
- - name: Get Date
- id: get-date
- run: |
- echo "date=$(date +'%Y-%b')"
- echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT
- shell: bash
- - uses: actions/cache@v4
- with:
- path: bioimageio_cache
- key: "test-spec-conda-${{ steps.get-date.outputs.date }}"
- - name: pytest-spec-conda
- run: pytest --disable-pytest-warnings
- env:
- BIOIMAGEIO_CACHE_PATH: bioimageio_cache
-
- test-spec-main:
+ - name: Get Date
+ id: get-date
+ run: |
+ echo "date=$(date +'%Y-%b')"
+ echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT
+ - id: cache-key
+ run: echo "cache-key=test-${{steps.get-date.outputs.date}}" >> $GITHUB_OUTPUT
+ - uses: actions/cache/restore@v4
+ id: look-up
+ with:
+ path: bioimageio_cache
+ key: ${{steps.cache-key.outputs.cache-key}}
+ lookup-only: true
+ - uses: actions/checkout@v4
+ if: steps.look-up.outputs.cache-hit != 'true'
+ - uses: actions/cache@v4
+ if: steps.look-up.outputs.cache-hit != 'true'
+ with:
+ path: bioimageio_cache
+ key: ${{steps.cache-key.outputs.cache-key}}
+ - uses: actions/setup-python@v5
+ if: steps.look-up.outputs.cache-hit != 'true'
+ with:
+ python-version: '3.12'
+ cache: 'pip'
+ - name: Install dependencies
+ if: steps.look-up.outputs.cache-hit != 'true'
+ run: |
+ pip install --upgrade pip
+ pip install -e .[dev]
+ - run: pytest --disable-pytest-warnings tests/test_bioimageio_collection.py::test_rdf_format_to_populate_cache
+ if: steps.look-up.outputs.cache-hit != 'true'
+ env:
+ BIOIMAGEIO_POPULATE_CACHE: '1'
+ BIOIMAGEIO_CACHE_PATH: bioimageio_cache
+ test:
+ needs: populate-cache
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ['3.8', '3.12']
include:
+ - python-version: '3.8'
+ conda-env: py38
+ spec: conda
+ - python-version: '3.8'
+ conda-env: py38
+ spec: main
+ - python-version: '3.9'
+ conda-env: dev
+ spec: conda
+ - python-version: '3.10'
+ conda-env: dev
+ spec: conda
+ - python-version: '3.11'
+ conda-env: full
+ spec: main
+ run-expensive-tests: true
+ report-coverage: true
+ save-cache: true
- python-version: '3.12'
- is-dev-version: true
+ conda-env: dev
+ spec: conda
+ - python-version: '3.13'
+ conda-env: dev
+ spec: main
+ save-cache: true
+
steps:
- uses: actions/checkout@v4
- - name: Install Conda environment with Micromamba
- if: matrix.python-version != '3.8'
- uses: mamba-org/setup-micromamba@v1
- with:
- cache-downloads: true
- cache-environment: true
- environment-file: dev/env-wo-python.yaml
- create-args: >-
- python=${{ matrix.python-version }}
- post-cleanup: 'all'
- - name: Install py3.8 environment
- if: matrix.python-version == '3.8'
- uses: mamba-org/setup-micromamba@v1
- with:
- cache-downloads: true
- cache-environment: true
- environment-file: dev/env-py38.yaml
- post-cleanup: 'all'
- - name: additional setup spec
+ - id: setup
+ run: |
+ echo "env-name=${{ matrix.spec }}-${{ matrix.conda-env }}-${{ matrix.python-version }}"
+ echo "env-name=${{ matrix.spec }}-${{ matrix.conda-env }}-${{ matrix.python-version }}" >> $GITHUB_OUTPUT
+ echo "env-file=dev/env-${{ matrix.conda-env }}.yaml"
+ echo "env-file=dev/env-${{ matrix.conda-env }}.yaml" >> $GITHUB_OUTPUT
+ - name: check on env-file
+ shell: python
run: |
- conda remove --yes --force bioimageio.spec || true # allow failure for cached env
- pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io
- - name: additional setup core
- run: pip install --no-deps -e .
+ from pathlib import Path
+ from pprint import pprint
+ if not (env_path:=Path("${{steps.setup.outputs.env-file}}")).exists():
+ if env_path.parent.exists():
+ pprint(env_path.parent.glob("*"))
+ else:
+ pprint(Path().glob("*"))
+ raise FileNotFoundError(f"{env_path} does not exist")
+
+ - uses: conda-incubator/setup-miniconda@v3
+ with:
+ auto-update-conda: true
+ auto-activate-base: true
+ activate-environment: ${{steps.setup.outputs.env-name}}
+ channel-priority: strict
+ miniforge-version: latest
- name: Get Date
id: get-date
run: |
- echo "date=$(date +'%Y-%b')"
- echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT
- shell: bash
- - uses: actions/cache@v4
+ echo "today=$(date -u '+%Y%m%d')"
+ echo "today=$(date -u '+%Y%m%d')" >> $GITHUB_OUTPUT
+ - name: Restore cached env
+ uses: actions/cache/restore@v4
+ with:
+ path: ${{env.CONDA}}/envs/${{steps.setup.outputs.env-name}}
+ key: >-
+ conda-${{runner.os}}-${{runner.arch}}
+ -${{steps.get-date.outputs.today}}
+ -${{hashFiles(steps.setup.outputs.env-file)}}
+ -${{env.CACHE_NUMBER}}
+ env:
+ CACHE_NUMBER: 0
+ id: cache-env
+ - name: Install env
+ run: conda env update --name=${{steps.setup.outputs.env-name}} --file=${{steps.setup.outputs.env-file}} python=${{matrix.python-version}}
+ if: steps.cache-env.outputs.cache-hit != 'true'
+ - name: Install uncached pip dependencies
+ run: |
+ pip install --upgrade pip
+ pip install --no-deps -e .
+ - name: Install uncached pip dependencies for 'full' environment
+ if: matrix.conda-env == 'full'
+ run: |
+ pip install git+https://github.com/ChaoningZhang/MobileSAM.git
+ - name: Cache env
+ if: steps.cache-env.outputs.cache-hit != 'true'
+ uses: actions/cache/save@v4
+ with:
+ path: ${{env.CONDA}}/envs/${{steps.setup.outputs.env-name}}
+ key: >-
+ conda-${{runner.os}}-${{runner.arch}}
+ -${{steps.get-date.outputs.today}}
+ -${{hashFiles(steps.setup.outputs.env-file)}}
+ -${{env.CACHE_NUMBER}}
+ env:
+ CACHE_NUMBER: 0
+ - run: conda list
+ - name: Pyright
+ run: |
+ pyright --version
+ pyright -p pyproject.toml --pythonversion ${{ matrix.python-version }}
+ if: matrix.run-expensive-tests
+ - name: Restore bioimageio cache
+ uses: actions/cache/restore@v4
+ id: bioimageio-cache
with:
path: bioimageio_cache
- key: "test-spec-main-${{ steps.get-date.outputs.date }}"
- - name: pytest-spec-main
+ key: ${{needs.populate-cache.outputs.cache-key}}${{matrix.run-expensive-tests && '' || '-light'}}
+ - name: pytest
run: pytest --disable-pytest-warnings
env:
BIOIMAGEIO_CACHE_PATH: bioimageio_cache
- - if: matrix.is-dev-version && github.event_name == 'pull_request'
+ RUN_EXPENSIVE_TESTS: ${{ matrix.run-expensive-tests && 'true' || 'false' }}
+ - name: Save updated bioimageio cache
+ if: matrix.save-cache
+ uses: actions/cache/save@v4
+ with:
+ path: bioimageio_cache
+ key: ${{needs.populate-cache.outputs.cache-key}}${{matrix.run-expensive-tests && '' || '-light'}}
+ - if: matrix.report-coverage && github.event_name == 'pull_request'
uses: orgoro/coverage@v3.2
with:
coverageFile: coverage.xml
- token: ${{ secrets.GITHUB_TOKEN }}
- - if: matrix.is-dev-version && github.ref == 'refs/heads/main'
+ token: ${{secrets.GITHUB_TOKEN}}
+ - if: matrix.report-coverage && github.ref == 'refs/heads/main'
run: |
pip install genbadge[coverage]
genbadge coverage --input-file coverage.xml --output-file ./dist/coverage/coverage-badge.svg
coverage html -d dist/coverage
- - if: matrix.is-dev-version && github.ref == 'refs/heads/main'
+ - if: matrix.report-coverage && github.ref == 'refs/heads/main'
uses: actions/upload-artifact@v4
with:
name: coverage
retention-days: 1
path: dist
-
- test-spec-main-tf:
- runs-on: ubuntu-latest
- strategy:
- matrix:
- python-version: ['3.9', '3.12']
- steps:
- - uses: actions/checkout@v4
- - uses: mamba-org/setup-micromamba@v1
- with:
- cache-downloads: true
- cache-environment: true
- environment-file: dev/env-tf.yaml
- condarc: |
- channel-priority: flexible
- create-args: >-
- python=${{ matrix.python-version }}
- post-cleanup: 'all'
- - name: additional setup spec
- run: |
- conda remove --yes --force bioimageio.spec || true # allow failure for cached env
- pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io
- - name: additional setup core
- run: pip install --no-deps -e .
- - name: Get Date
- id: get-date
- run: |
- echo "date=$(date +'%Y-%b')"
- echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT
- shell: bash
- - uses: actions/cache@v4
- with:
- path: bioimageio_cache
- key: "test-spec-main-tf-${{ steps.get-date.outputs.date }}"
- - run: pytest --disable-pytest-warnings
- env:
- BIOIMAGEIO_CACHE_PATH: bioimageio_cache
-
- test-spec-conda-tf:
- runs-on: ubuntu-latest
- strategy:
- matrix:
- python-version: ['3.9', '3.12']
- steps:
- - uses: actions/checkout@v4
- - uses: mamba-org/setup-micromamba@v1
- with:
- cache-downloads: true
- cache-environment: true
- environment-file: dev/env-tf.yaml
- condarc: |
- channel-priority: flexible
- create-args: >-
- python=${{ matrix.python-version }}
- post-cleanup: 'all'
- - name: additional setup
- run: pip install --no-deps -e .
- - name: Get Date
- id: get-date
- run: |
- echo "date=$(date +'%Y-%b')"
- echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT
- shell: bash
- - uses: actions/cache@v4
- with:
- path: bioimageio_cache
- key: "test-spec-conda-tf-${{ steps.get-date.outputs.date }}"
- - name: pytest-spec-tf
- run: pytest --disable-pytest-warnings
- env:
- BIOIMAGEIO_CACHE_PATH: bioimageio_cache
-
conda-build:
+ needs: test
runs-on: ubuntu-latest
- needs: test-spec-conda
steps:
- - name: checkout
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- - name: Install Conda environment with Micromamba
- uses: mamba-org/setup-micromamba@v1
- with:
- cache-downloads: true
- cache-environment: true
- environment-name: build-env
- condarc: |
- channels:
- - conda-forge
- create-args: |
- boa
- - name: linux conda build
- run: |
- conda mambabuild -c conda-forge conda-recipe
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - uses: conda-incubator/setup-miniconda@v3
+ with:
+ auto-update-conda: true
+ auto-activate-base: true
+ activate-environment: ""
+ channel-priority: strict
+ miniforge-version: latest
+ conda-solver: libmamba
+ - name: install common conda dependencies
+ run: conda install -n base -c conda-forge conda-build -y
+ - uses: actions/cache@v4
+ with:
+ path: |
+ pkgs/noarch
+ pkgs/channeldata.json
+ key: ${{ github.sha }}-packages
+ - name: linux conda build test
+ shell: bash -l {0}
+ run: |
+ mkdir -p ./pkgs/noarch
+ conda-build -c conda-forge conda-recipe --no-test --output-folder ./pkgs
docs:
- needs: [test-spec-main]
+ needs: test
if: github.ref == 'refs/heads/main'
runs-on: ubuntu-latest
- defaults:
- run:
- shell: bash -l {0}
steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
diff --git a/.gitignore b/.gitignore
index d8be60be..688e4a88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,9 +6,12 @@ __pycache__/
*.egg-info/
*.pyc
**/tmp
+bioimageio_unzipped_tf_weights/
build/
cache
coverage.xml
dist/
docs/
+dogfood/
typings/pooch/
+bioimageio_cache/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ef0eba58..2bee435e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,7 +7,7 @@ repos:
rev: v0.3.2
hooks:
- id: ruff
- args: [--fix]
+ args: [--fix, --show-fixes]
- repo: local
hooks:
- id: pyright
diff --git a/README.md b/README.md
index 92bcb9b0..1a1aa932 100644
--- a/README.md
+++ b/README.md
@@ -275,6 +275,7 @@ functionality, but not any functionality depending on model prediction.
To install additional deep learning libraries add `pytorch`, `onnxruntime`, `keras` or `tensorflow`.
Deeplearning frameworks to consider installing alongside `bioimageio.core`:
+
- [Pytorch/Torchscript](https://pytorch.org/get-started/locally/)
- [TensorFlow](https://www.tensorflow.org/install)
- [ONNXRuntime](https://onnxruntime.ai/docs/install/#python-installs)
@@ -297,13 +298,16 @@ These models are described by---and can be loaded with---the bioimageio.spec pac
In addition bioimageio.core provides functionality to convert model weight formats.
### Documentation
+
[Here you find the bioimageio.core documentation.](https://bioimage-io.github.io/core-bioimage-io-python/bioimageio/core.html)
#### Presentations
+
- [Create a model from scratch](https://bioimage-io.github.io/core-bioimage-io-python/presentations/create_ambitious_sloth.slides.html) ([source](https://github.com/bioimage-io/core-bioimage-io-python/tree/main/presentations))
#### Examples
-
+
+
- Notebooks that save and load resource descriptions and validate their format (using bioimageio.spec, a dependency of bioimageio.core)
- load_model_and_create_your_own.ipynb
@@ -327,7 +331,6 @@ bioimageio
For examples see [Get started](#get-started).
-
### CLI inputs from file
For convenience the command line options (not arguments) may be given in a `bioimageio-cli.json`
@@ -342,7 +345,6 @@ blockwise: true
stats: inputs/dataset_statistics.json
```
-
## Set up Development Environment
To set up a development conda environment run the following commands:
@@ -355,15 +357,21 @@ pip install -e . --no-deps
There are different environment files available that only install tensorflow or pytorch as dependencies, see [dev folder](https://github.com/bioimage-io/core-bioimage-io-python/tree/main/dev).
-
## Logging level
`bioimageio.spec` and `bioimageio.core` use [loguru](https://github.com/Delgan/loguru) for logging, hence the logging level
may be controlled with the `LOGURU_LEVEL` environment variable.
-
## Changelog
+### 0.7.1 (to be released)
+
+- breaking: removed `decimals` argument from bioimageio CLI and `bioimageio.core.commands.test()`
+- New feature: `bioimageio.core.test_description` accepts **runtime_env** and **run_command** to test a resource
+ using the conda environment described by that resource (or another specified conda env)
+- new CLI command: `bioimageio add-weights` (and utility function: bioimageio.core.add_weights)
+- removed `bioimageio.core.proc_ops.get_proc_class` in favor of `bioimageio.core.proc_ops.get_proc`
+
### 0.7.0
- breaking:
diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py
index 613cd85d..c7554372 100644
--- a/bioimageio/core/__init__.py
+++ b/bioimageio/core/__init__.py
@@ -41,6 +41,7 @@
)
from ._settings import settings
from .axis import Axis, AxisId
+from .backends import create_model_adapter
from .block_meta import BlockMeta
from .common import MemberId
from .prediction import predict, predict_many
@@ -49,6 +50,7 @@
from .stat_measures import Stat
from .tensor import Tensor
from .utils import VERSION
+from .weight_converters import add_weights
__version__ = VERSION
@@ -63,6 +65,7 @@
__all__ = [
"__version__",
+ "add_weights",
"axis",
"Axis",
"AxisId",
@@ -73,6 +76,7 @@
"commands",
"common",
"compute_dataset_measures",
+ "create_model_adapter",
"create_prediction_pipeline",
"digest_spec",
"dump_description",
diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py
index 9da63bf5..436448f5 100644
--- a/bioimageio/core/__main__.py
+++ b/bioimageio/core/__main__.py
@@ -1,4 +1,4 @@
-from bioimageio.core.cli import Bioimageio
+from .cli import Bioimageio
def main():
diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py
index 9f5ccc3e..0b7717aa 100644
--- a/bioimageio/core/_prediction_pipeline.py
+++ b/bioimageio/core/_prediction_pipeline.py
@@ -12,14 +12,21 @@
Union,
)
+from loguru import logger
from tqdm import tqdm
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
-from bioimageio.spec.model.v0_5 import WeightsFormat
from ._op_base import BlockedOperator
from .axis import AxisId, PerAxis
-from .common import Halo, MemberId, PerMember, SampleId
+from .common import (
+ BlocksizeParameter,
+ Halo,
+ MemberId,
+ PerMember,
+ SampleId,
+ SupportedWeightsFormat,
+)
from .digest_spec import (
get_block_transform,
get_input_halo,
@@ -43,7 +50,8 @@
class PredictionPipeline:
"""
Represents model computation including preprocessing and postprocessing
- Note: Ideally use the PredictionPipeline as a context manager
+ Note: Ideally use the `PredictionPipeline` in a with statement
+ (as a context manager).
"""
def __init__(
@@ -54,13 +62,20 @@ def __init__(
preprocessing: List[Processing],
postprocessing: List[Processing],
model_adapter: ModelAdapter,
- default_ns: Union[
- v0_5.ParameterizedSize_N,
- Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
- ] = 10,
+ default_ns: Optional[BlocksizeParameter] = None,
+ default_blocksize_parameter: BlocksizeParameter = 10,
default_batch_size: int = 1,
) -> None:
+ """Use `create_prediction_pipeline` to create a `PredictionPipeline`"""
super().__init__()
+ default_blocksize_parameter = default_ns or default_blocksize_parameter
+ if default_ns is not None:
+ warnings.warn(
+ "Argument `default_ns` is deprecated in favor of"
+ " `default_blocksize_parameter` and will be removed soon."
+ )
+ del default_ns
+
if model_description.run_mode:
warnings.warn(
f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
@@ -88,7 +103,7 @@ def __init__(
)
self._block_transform = get_block_transform(model_description)
- self._default_ns = default_ns
+ self._default_blocksize_parameter = default_blocksize_parameter
self._default_batch_size = default_batch_size
self._input_ids = get_member_ids(model_description.inputs)
@@ -121,19 +136,9 @@ def predict_sample_block(
self.apply_preprocessing(sample_block)
output_meta = sample_block.get_transformed_meta(self._block_transform)
- output = output_meta.with_data(
- {
- tid: out
- for tid, out in zip(
- self._output_ids,
- self._adapter.forward(
- *(sample_block.members.get(t) for t in self._input_ids)
- ),
- )
- if out is not None
- },
- stat=sample_block.stat,
- )
+ local_output = self._adapter.forward(sample_block)
+
+ output = output_meta.with_data(local_output.members, stat=local_output.stat)
if not skip_postprocessing:
self.apply_postprocessing(output)
@@ -152,20 +157,7 @@ def predict_sample_without_blocking(
if not skip_preprocessing:
self.apply_preprocessing(sample)
- output = Sample(
- members={
- out_id: out
- for out_id, out in zip(
- self._output_ids,
- self._adapter.forward(
- *(sample.members.get(in_id) for in_id in self._input_ids)
- ),
- )
- if out is not None
- },
- stat=sample.stat,
- id=sample.id,
- )
+ output = self._adapter.forward(sample)
if not skip_postprocessing:
self.apply_postprocessing(output)
@@ -197,9 +189,15 @@ def predict_sample_with_fixed_blocking(
)
input_blocks = list(input_blocks)
predicted_blocks: List[SampleBlock] = []
+ logger.info(
+ "split sample shape {} into {} blocks of {}.",
+ {k: dict(v) for k, v in sample.shape.items()},
+ n_blocks,
+ {k: dict(v) for k, v in input_block_shape.items()},
+ )
for b in tqdm(
input_blocks,
- desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
+ desc=f"predict {sample.id or ''} with {self.model_description.id or self.model_description.name}",
unit="block",
unit_divisor=1,
total=n_blocks,
@@ -238,7 +236,7 @@ def predict_sample_with_blocking(
+ " Consider using `predict_sample_with_fixed_blocking`"
)
- ns = ns or self._default_ns
+ ns = ns or self._default_blocksize_parameter
if isinstance(ns, int):
ns = {
(ipt.id, a.id): ns
@@ -319,18 +317,16 @@ def create_prediction_pipeline(
bioimageio_model: AnyModelDescr,
*,
devices: Optional[Sequence[str]] = None,
- weight_format: Optional[WeightsFormat] = None,
- weights_format: Optional[WeightsFormat] = None,
+ weight_format: Optional[SupportedWeightsFormat] = None,
+ weights_format: Optional[SupportedWeightsFormat] = None,
dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
keep_updating_initial_dataset_statistics: bool = False,
fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
{}
),
model_adapter: Optional[ModelAdapter] = None,
- ns: Union[
- v0_5.ParameterizedSize_N,
- Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
- ] = 10,
+ ns: Optional[BlocksizeParameter] = None,
+ default_blocksize_parameter: BlocksizeParameter = 10,
**deprecated_kwargs: Any,
) -> PredictionPipeline:
"""
@@ -340,9 +336,33 @@ def create_prediction_pipeline(
* model prediction
* computation of output statistics
* postprocessing
+
+ Args:
+ bioimageio_model: A bioimageio model description.
+ devices: (optional)
+ weight_format: deprecated in favor of **weights_format**
+ weights_format: (optional) Use a specific **weights_format** rather than
+ choosing one automatically.
+ A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
+ created to run inference with the **bioimageio_model**.
+ dataset_for_initial_statistics: (optional) If preprocessing steps require input
+ dataset statistics, **dataset_for_initial_statistics** allows you to
+ specify a dataset from which these statistics are computed.
+ keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
+ to update dataset statistics with each processed sample.
+ fixed_dataset_statistics: (optional) Allows you to specify a mapping of
+ `DatasetMeasure`s to precomputed `MeasureValue`s.
+ model_adapter: (optional) Allows you to use a custom **model_adapter** instead
+ of creating one according to the present/selected **weights_format**.
+ ns: deprecated in favor of **default_blocksize_parameter**
+ default_blocksize_parameter: Allows to control the default block size for
+ blockwise predictions, see `BlocksizeParameter`.
+
"""
weights_format = weight_format or weights_format
del weight_format
+ default_blocksize_parameter = ns or default_blocksize_parameter
+ del ns
if deprecated_kwargs:
warnings.warn(
f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
@@ -377,5 +397,5 @@ def dataset():
model_adapter=model_adapter,
preprocessing=preprocessing,
postprocessing=postprocessing,
- default_ns=ns,
+ default_blocksize_parameter=default_blocksize_parameter,
)
diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py
index 6ace6d5c..327e540a 100644
--- a/bioimageio/core/_resource_tests.py
+++ b/bioimageio/core/_resource_tests.py
@@ -1,21 +1,53 @@
-import traceback
+import hashlib
+import os
+import platform
+import subprocess
import warnings
+from io import StringIO
from itertools import product
-from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import (
+ Callable,
+ Dict,
+ Hashable,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+ overload,
+)
-import numpy as np
from loguru import logger
+from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
from bioimageio.spec import (
+ BioimageioCondaEnv,
InvalidDescr,
+ LatestResourceDescr,
ResourceDescr,
+ ValidationContext,
build_description,
dump_description,
+ get_conda_env,
load_description,
+ save_bioimageio_package,
)
+from bioimageio.spec._description_impl import DISCOVER
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
-from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource
-from bioimageio.spec.get_conda_env import get_conda_env
+from bioimageio.spec._internal.io import is_yaml_value
+from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
+from bioimageio.spec._internal.types import (
+ AbsoluteTolerance,
+ FormatVersionPlaceholder,
+ MismatchedElementsPerMillion,
+ RelativeTolerance,
+)
+from bioimageio.spec._internal.validation_context import get_validation_context
+from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import WeightsFormat
from bioimageio.spec.summary import (
@@ -27,19 +59,38 @@
from ._prediction_pipeline import create_prediction_pipeline
from .axis import AxisId, BatchSize
+from .common import MemberId, SupportedWeightsFormat
from .digest_spec import get_test_inputs, get_test_outputs
from .sample import Sample
from .utils import VERSION
-def enable_determinism(mode: Literal["seed_only", "full"]):
+class DeprecatedKwargs(TypedDict):
+ absolute_tolerance: NotRequired[AbsoluteTolerance]
+ relative_tolerance: NotRequired[RelativeTolerance]
+ decimal: NotRequired[Optional[int]]
+
+
+def enable_determinism(
+ mode: Literal["seed_only", "full"] = "full",
+ weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
+):
"""Seed and configure ML frameworks for maximum reproducibility.
May degrade performance. Only recommended for testing reproducibility!
Seed any random generators and (if **mode**=="full") request ML frameworks to use
deterministic algorithms.
+
+ Args:
+ mode: determinism mode
+ - 'seed_only' -- only set seeds, or
+ - 'full' -- determinism features (might degrade performance or throw exceptions)
+ weight_formats: Limit which deep learning frameworks are imported
+ based on weight_formats.
+ E.g. this allows to avoid importing tensorflow when testing with pytorch.
+
Notes:
- - **mode** == "full" might degrade performance and throw exceptions.
+ - **mode** == "full" might degrade performance or throw exceptions.
- Subsequent inference calls might still differ. Call before each function
(sequence) that is expected to be reproducible.
- Degraded performance: Use for testing reproducibility only!
@@ -58,120 +109,389 @@ def enable_determinism(mode: Literal["seed_only", "full"]):
except Exception as e:
logger.debug(str(e))
- try:
+ if (
+ weight_formats is None
+ or "pytorch_state_dict" in weight_formats
+ or "torchscript" in weight_formats
+ ):
try:
- import torch
- except ImportError:
- pass
- else:
- _ = torch.manual_seed(0)
- torch.use_deterministic_algorithms(mode == "full")
- except Exception as e:
- logger.debug(str(e))
+ try:
+ import torch
+ except ImportError:
+ pass
+ else:
+ _ = torch.manual_seed(0)
+ torch.use_deterministic_algorithms(mode == "full")
+ except Exception as e:
+ logger.debug(str(e))
- try:
+ if (
+ weight_formats is None
+ or "tensorflow_saved_model_bundle" in weight_formats
+ or "keras_hdf5" in weight_formats
+ ):
try:
- import keras
- except ImportError:
- pass
- else:
- keras.utils.set_random_seed(0)
- except Exception as e:
- logger.debug(str(e))
-
- try:
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+ try:
+ import tensorflow as tf # pyright: ignore[reportMissingTypeStubs]
+ except ImportError:
+ pass
+ else:
+ tf.random.set_seed(0)
+ if mode == "full":
+ tf.config.experimental.enable_op_determinism()
+ # TODO: find possibility to switch it off again??
+ except Exception as e:
+ logger.debug(str(e))
+
+ if weight_formats is None or "keras_hdf5" in weight_formats:
try:
- import tensorflow as tf # pyright: ignore[reportMissingImports]
- except ImportError:
- pass
- else:
- tf.random.seed(0)
- if mode == "full":
- tf.config.experimental.enable_op_determinism()
- # TODO: find possibility to switch it off again??
- except Exception as e:
- logger.debug(str(e))
+ try:
+ import keras # pyright: ignore[reportMissingTypeStubs]
+ except ImportError:
+ pass
+ else:
+ keras.utils.set_random_seed(0)
+ except Exception as e:
+ logger.debug(str(e))
def test_model(
- source: Union[v0_5.ModelDescr, PermissiveFileSource],
- weight_format: Optional[WeightsFormat] = None,
+ source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
+ weight_format: Optional[SupportedWeightsFormat] = None,
devices: Optional[List[str]] = None,
- absolute_tolerance: float = 1.5e-4,
- relative_tolerance: float = 1e-4,
- decimal: Optional[int] = None,
*,
determinism: Literal["seed_only", "full"] = "seed_only",
+ sha256: Optional[Sha256] = None,
+ stop_early: bool = False,
+ **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
"""Test model inference"""
return test_description(
source,
weight_format=weight_format,
devices=devices,
- absolute_tolerance=absolute_tolerance,
- relative_tolerance=relative_tolerance,
- decimal=decimal,
determinism=determinism,
expected_type="model",
+ sha256=sha256,
+ stop_early=stop_early,
+ **deprecated,
)
+def default_run_command(args: Sequence[str]):
+ logger.info("running '{}'...", " ".join(args))
+ _ = subprocess.run(args, shell=True, text=True, check=True)
+
+
def test_description(
source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
*,
- format_version: Union[Literal["discover", "latest"], str] = "discover",
- weight_format: Optional[WeightsFormat] = None,
+ format_version: Union[FormatVersionPlaceholder, str] = "discover",
+ weight_format: Optional[SupportedWeightsFormat] = None,
devices: Optional[Sequence[str]] = None,
- absolute_tolerance: float = 1.5e-4,
- relative_tolerance: float = 1e-4,
- decimal: Optional[int] = None,
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Optional[str] = None,
+ sha256: Optional[Sha256] = None,
+ stop_early: bool = False,
+ runtime_env: Union[
+ Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
+ ] = ("currently-active"),
+ run_command: Callable[[Sequence[str]], None] = default_run_command,
+ **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
- """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
- rd = load_description_and_test(
- source,
- format_version=format_version,
- weight_format=weight_format,
- devices=devices,
- absolute_tolerance=absolute_tolerance,
- relative_tolerance=relative_tolerance,
- decimal=decimal,
- determinism=determinism,
- expected_type=expected_type,
+ """Test a bioimage.io resource dynamically,
+ for example run prediction of test tensors for models.
+
+ Args:
+ source: model description source.
+ weight_format: Weight format to test.
+ Default: All weight formats present in **source**.
+ devices: Devices to test with, e.g. 'cpu', 'cuda'.
+ Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
+ determinism: Modes to improve reproducibility of test outputs.
+ expected_type: Assert an expected resource description `type`.
+ sha256: Expected SHA256 value of **source**.
+ (Ignored if **source** already is a loaded `ResourceDescr` object.)
+ stop_early: Do not run further subtests after a failed one.
+ runtime_env: (Experimental feature!) The Python environment to run the tests in
+ - `"currently-active"`: Use active Python interpreter.
+ - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
+ environment YAML file based on the model weights description.
+ - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
+ Note: The `bioimageio.core` dependency will be added automatically if not present.
+ run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess
+ (ignored if **runtime_env** is `"currently-active"`).
+ """
+ if runtime_env == "currently-active":
+ rd = load_description_and_test(
+ source,
+ format_version=format_version,
+ weight_format=weight_format,
+ devices=devices,
+ determinism=determinism,
+ expected_type=expected_type,
+ sha256=sha256,
+ stop_early=stop_early,
+ **deprecated,
+ )
+ return rd.validation_summary
+
+ if runtime_env == "as-described":
+ conda_env = None
+ elif isinstance(runtime_env, (str, Path)):
+ conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
+ elif isinstance(runtime_env, BioimageioCondaEnv):
+ conda_env = runtime_env
+ else:
+ assert_never(runtime_env)
+
+ with TemporaryDirectory(ignore_cleanup_errors=True) as _d:
+ working_dir = Path(_d)
+ if isinstance(source, (dict, ResourceDescrBase)):
+ file_source = save_bioimageio_package(
+ source, output_path=working_dir / "package.zip"
+ )
+ else:
+ file_source = source
+
+ return _test_in_env(
+ file_source,
+ working_dir=working_dir,
+ weight_format=weight_format,
+ conda_env=conda_env,
+ devices=devices,
+ determinism=determinism,
+ expected_type=expected_type,
+ sha256=sha256,
+ stop_early=stop_early,
+ run_command=run_command,
+ **deprecated,
+ )
+
+
+def _test_in_env(
+ source: PermissiveFileSource,
+ *,
+ working_dir: Path,
+ weight_format: Optional[SupportedWeightsFormat],
+ conda_env: Optional[BioimageioCondaEnv],
+ devices: Optional[Sequence[str]],
+ determinism: Literal["seed_only", "full"],
+ run_command: Callable[[Sequence[str]], None],
+ stop_early: bool,
+ expected_type: Optional[str],
+ sha256: Optional[Sha256],
+ **deprecated: Unpack[DeprecatedKwargs],
+) -> ValidationSummary:
+ descr = load_description(source)
+
+ if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
+ raise NotImplementedError("Not yet implemented for non-model resources")
+
+ if weight_format is None:
+ all_present_wfs = [
+ wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
+ ]
+ ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
+ logger.info(
+ "Found weight formats {}. Start testing all{}...",
+ all_present_wfs,
+ f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
+ )
+ summary = _test_in_env(
+ source,
+ working_dir=working_dir / all_present_wfs[0],
+ weight_format=all_present_wfs[0],
+ devices=devices,
+ determinism=determinism,
+ conda_env=conda_env,
+ run_command=run_command,
+ expected_type=expected_type,
+ sha256=sha256,
+ stop_early=stop_early,
+ **deprecated,
+ )
+ for wf in all_present_wfs[1:]:
+ additional_summary = _test_in_env(
+ source,
+ working_dir=working_dir / wf,
+ weight_format=wf,
+ devices=devices,
+ determinism=determinism,
+ conda_env=conda_env,
+ run_command=run_command,
+ expected_type=expected_type,
+ sha256=sha256,
+ stop_early=stop_early,
+ **deprecated,
+ )
+ for d in additional_summary.details:
+ # TODO: filter redundant details; group details
+ summary.add_detail(d)
+ return summary
+
+ if weight_format == "pytorch_state_dict":
+ wf = descr.weights.pytorch_state_dict
+ elif weight_format == "torchscript":
+ wf = descr.weights.torchscript
+ elif weight_format == "keras_hdf5":
+ wf = descr.weights.keras_hdf5
+ elif weight_format == "onnx":
+ wf = descr.weights.onnx
+ elif weight_format == "tensorflow_saved_model_bundle":
+ wf = descr.weights.tensorflow_saved_model_bundle
+ elif weight_format == "tensorflow_js":
+ raise RuntimeError(
+ "testing 'tensorflow_js' is not supported by bioimageio.core"
+ )
+ else:
+ assert_never(weight_format)
+
+ assert wf is not None
+ if conda_env is None:
+ conda_env = get_conda_env(entry=wf)
+
+ # remove name as we create a name based on the env description hash value
+ conda_env.name = None
+
+ dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
+ if not is_yaml_value(dumped_env):
+ raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")
+
+ env_io = StringIO()
+ write_yaml(dumped_env, file=env_io)
+ encoded_env = env_io.getvalue().encode()
+ env_name = hashlib.sha256(encoded_env).hexdigest()
+
+ try:
+ run_command(["where" if platform.system() == "Windows" else "which", "conda"])
+ except Exception as e:
+ raise RuntimeError("Conda not available") from e
+
+ working_dir.mkdir(parents=True, exist_ok=True)
+ try:
+ run_command(["conda", "activate", env_name])
+ except Exception:
+ path = working_dir / "env.yaml"
+ _ = path.write_bytes(encoded_env)
+ logger.debug("written conda env to {}", path)
+ run_command(["conda", "env", "create", f"--file={path}", f"--name={env_name}"])
+ run_command(["conda", "activate", env_name])
+
+ summary_path = working_dir / "summary.json"
+ run_command(
+ [
+ "conda",
+ "run",
+ "-n",
+ env_name,
+ "bioimageio",
+ "test",
+ str(source),
+ f"--summary-path={summary_path}",
+ f"--determinism={determinism}",
+ ]
+ + ([f"--expected-type={expected_type}"] if expected_type else [])
+ + (["--stop-early"] if stop_early else [])
)
- return rd.validation_summary
+ return ValidationSummary.model_validate_json(summary_path.read_bytes())
+@overload
def load_description_and_test(
source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
*,
- format_version: Union[Literal["discover", "latest"], str] = "discover",
- weight_format: Optional[WeightsFormat] = None,
+ format_version: Literal["latest"],
+ weight_format: Optional[SupportedWeightsFormat] = None,
devices: Optional[Sequence[str]] = None,
- absolute_tolerance: float = 1.5e-4,
- relative_tolerance: float = 1e-4,
- decimal: Optional[int] = None,
determinism: Literal["seed_only", "full"] = "seed_only",
expected_type: Optional[str] = None,
+ sha256: Optional[Sha256] = None,
+ stop_early: bool = False,
+ **deprecated: Unpack[DeprecatedKwargs],
+) -> Union[LatestResourceDescr, InvalidDescr]: ...
+
+
+@overload
+def load_description_and_test(
+ source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
+ *,
+ format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
+ weight_format: Optional[SupportedWeightsFormat] = None,
+ devices: Optional[Sequence[str]] = None,
+ determinism: Literal["seed_only", "full"] = "seed_only",
+ expected_type: Optional[str] = None,
+ sha256: Optional[Sha256] = None,
+ stop_early: bool = False,
+ **deprecated: Unpack[DeprecatedKwargs],
+) -> Union[ResourceDescr, InvalidDescr]: ...
+
+
+def load_description_and_test(
+ source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
+ *,
+ format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
+ weight_format: Optional[SupportedWeightsFormat] = None,
+ devices: Optional[Sequence[str]] = None,
+ determinism: Literal["seed_only", "full"] = "seed_only",
+ expected_type: Optional[str] = None,
+ sha256: Optional[Sha256] = None,
+ stop_early: bool = False,
+ **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
- """Test RDF dynamically, e.g. model inference of test inputs"""
- if (
- isinstance(source, ResourceDescrBase)
- and format_version != "discover"
- and source.format_version != format_version
- ):
- warnings.warn(
- f"deserializing source to ensure we validate and test using format {format_version}"
- )
- source = dump_description(source)
+ """Test a bioimage.io resource dynamically,
+ for example run prediction of test tensors for models.
+
+ See `test_description` for more details.
+
+ Returns:
+ A (possibly invalid) resource description object
+ with a populated `.validation_summary` attribute.
+ """
+ if isinstance(source, ResourceDescrBase):
+ root = source.root
+ file_name = source.file_name
+ if (
+ (
+ format_version
+ not in (
+ DISCOVER,
+ source.format_version,
+ ".".join(source.format_version.split(".")[:2]),
+ )
+ )
+ or (c := source.validation_summary.details[0].context) is None
+ or not c.perform_io_checks
+ ):
+ logger.debug(
+ "deserializing source to ensure we validate and test using format {} and perform io checks",
+ format_version,
+ )
+ source = dump_description(source)
+ else:
+ root = Path()
+ file_name = None
if isinstance(source, ResourceDescrBase):
rd = source
elif isinstance(source, dict):
- rd = build_description(source, format_version=format_version)
+ # check context for a given root; default to root of source
+ context = get_validation_context(
+ ValidationContext(root=root, file_name=file_name)
+ ).replace(
+ perform_io_checks=True # make sure we perform io checks though
+ )
+
+ rd = build_description(
+ source,
+ format_version=format_version,
+ context=context,
+ )
else:
- rd = load_description(source, format_version=format_version)
+ rd = load_description(
+ source, format_version=format_version, sha256=sha256, perform_io_checks=True
+ )
rd.validation_summary.env.add(
InstalledPackage(name="bioimageio.core", version=VERSION)
@@ -182,50 +502,106 @@ def load_description_and_test(
if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
if weight_format is None:
- weight_formats: List[WeightsFormat] = [
+ weight_formats: List[SupportedWeightsFormat] = [
w for w, we in rd.weights if we is not None
] # pyright: ignore[reportAssignmentType]
else:
weight_formats = [weight_format]
- if decimal is None:
- atol = absolute_tolerance
- rtol = relative_tolerance
- else:
- warnings.warn(
- "The argument `decimal` has been deprecated in favour of"
- + " `relative_tolerance` and `absolute_tolerance`, with different"
- + " validation logic, using `numpy.testing.assert_allclose, see"
- + " 'https://numpy.org/doc/stable/reference/generated/"
- + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
- + " will cause validation to revert to the old behaviour."
- )
- atol = 1.5 * 10 ** (-decimal)
- rtol = 0
-
- enable_determinism(determinism)
+ enable_determinism(determinism, weight_formats=weight_formats)
for w in weight_formats:
- _test_model_inference(rd, w, devices, atol, rtol)
+ _test_model_inference(rd, w, devices, **deprecated)
+ if stop_early and rd.validation_summary.status == "failed":
+ break
+
if not isinstance(rd, v0_4.ModelDescr):
- _test_model_inference_parametrized(rd, w, devices)
+ _test_model_inference_parametrized(
+ rd, w, devices, stop_early=stop_early
+ )
+ if stop_early and rd.validation_summary.status == "failed":
+ break
# TODO: add execution of jupyter notebooks
# TODO: add more tests
+ if rd.validation_summary.status == "valid-format":
+ rd.validation_summary.status = "passed"
+
return rd
+def _get_tolerance(
+ model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
+ wf: SupportedWeightsFormat,
+ m: MemberId,
+ **deprecated: Unpack[DeprecatedKwargs],
+) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
+ if isinstance(model, v0_5.ModelDescr):
+ applicable = v0_5.ReproducibilityTolerance()
+
+ # check legacy test kwargs for weight format specific tolerance
+ if model.config.bioimageio.model_extra is not None:
+ for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
+ "test_kwargs", {}
+ ).items():
+ if wf == weights_format:
+ applicable = v0_5.ReproducibilityTolerance(
+ relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
+ absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
+ )
+ break
+
+ # check for weights format and output tensor specific tolerance
+ for a in model.config.bioimageio.reproducibility_tolerance:
+ if (not a.weights_formats or wf in a.weights_formats) and (
+ not a.output_ids or m in a.output_ids
+ ):
+ applicable = a
+ break
+
+ rtol = applicable.relative_tolerance
+ atol = applicable.absolute_tolerance
+ mismatched_tol = applicable.mismatched_elements_per_million
+ elif (decimal := deprecated.get("decimal")) is not None:
+ warnings.warn(
+ "The argument `decimal` has been deprecated in favour of"
+ + " `relative_tolerance` and `absolute_tolerance`, with different"
+ + " validation logic, using `numpy.testing.assert_allclose, see"
+ + " 'https://numpy.org/doc/stable/reference/generated/"
+ + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
+ + " will cause validation to revert to the old behaviour."
+ )
+ atol = 1.5 * 10 ** (-decimal)
+ rtol = 0
+ mismatched_tol = 0
+ else:
+ # use given (deprecated) test kwargs
+ atol = deprecated.get("absolute_tolerance", 1e-5)
+ rtol = deprecated.get("relative_tolerance", 1e-3)
+ mismatched_tol = 0
+
+ return rtol, atol, mismatched_tol
+
+
def _test_model_inference(
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- weight_format: WeightsFormat,
+ weight_format: SupportedWeightsFormat,
devices: Optional[Sequence[str]],
- atol: float,
- rtol: float,
+ **deprecated: Unpack[DeprecatedKwargs],
) -> None:
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
- logger.info("starting '{}'", test_name)
- error: Optional[str] = None
- tb: List[str] = []
+ logger.debug("starting '{}'", test_name)
+ errors: List[ErrorEntry] = []
+
+ def add_error_entry(msg: str, with_traceback: bool = False):
+ errors.append(
+ ErrorEntry(
+ loc=("weights", weight_format),
+ msg=msg,
+ type="bioimageio.core",
+ with_traceback=with_traceback,
+ )
+ )
try:
inputs = get_test_inputs(model)
@@ -237,54 +613,66 @@ def _test_model_inference(
results = prediction_pipeline.predict_sample_without_blocking(inputs)
if len(results.members) != len(expected.members):
- error = f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
+ add_error_entry(
+ f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
+ )
else:
- for m, exp in expected.members.items():
- res = results.members.get(m)
- if res is None:
- error = "Output tensors for test case may not be None"
+ for m, expected in expected.members.items():
+ actual = results.members.get(m)
+ if actual is None:
+ add_error_entry("Output tensors for test case may not be None")
break
- try:
- np.testing.assert_allclose(
- res.data,
- exp.data,
- rtol=rtol,
- atol=atol,
+
+ rtol, atol, mismatched_tol = _get_tolerance(
+ model, wf=weight_format, m=m, **deprecated
+ )
+ mismatched = (abs_diff := abs(actual - expected)) > atol + rtol * abs(
+ expected
+ )
+ mismatched_elements = mismatched.sum().item()
+ if mismatched_elements / expected.size > mismatched_tol / 1e6:
+ r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
+ r_max = r_diff[r_max_idx].item()
+ r_actual = actual[r_max_idx].item()
+ r_expected = expected[r_max_idx].item()
+ a_max_idx = abs_diff.argmax()
+ a_max = abs_diff[a_max_idx].item()
+ a_actual = actual[a_max_idx].item()
+ a_expected = expected[a_max_idx].item()
+ add_error_entry(
+ f"Output '{m}' disagrees with {mismatched_elements} of"
+ + f" {expected.size} expected values."
+ + f"\n Max relative difference: {r_max:.2e}"
+ + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
+ + f" at {r_max_idx}"
+ + f"\n Max absolute difference: {a_max:.2e}"
+ + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
)
- except AssertionError as e:
- error = f"Output and expected output disagree:\n {e}"
break
except Exception as e:
- error = str(e)
- tb = traceback.format_tb(e.__traceback__)
+ if get_validation_context().raise_errors:
+ raise e
+
+ add_error_entry(str(e), with_traceback=True)
model.validation_summary.add_detail(
ValidationDetail(
name=test_name,
loc=("weights", weight_format),
- status="passed" if error is None else "failed",
+ status="failed" if errors else "passed",
recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
- errors=(
- []
- if error is None
- else [
- ErrorEntry(
- loc=("weights", weight_format),
- msg=error,
- type="bioimageio.core",
- traceback=tb,
- )
- ]
- ),
+ errors=errors,
)
)
def _test_model_inference_parametrized(
model: v0_5.ModelDescr,
- weight_format: WeightsFormat,
+ weight_format: SupportedWeightsFormat,
devices: Optional[Sequence[str]],
+ *,
+ stop_early: bool,
) -> None:
if not any(
isinstance(a.size, v0_5.ParameterizedSize)
@@ -311,11 +699,13 @@ def _test_model_inference_parametrized(
# no batch axis
batch_sizes = {1}
- test_cases: Set[Tuple[v0_5.ParameterizedSize_N, BatchSize]] = {
- (n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
+ test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
+ (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
}
logger.info(
- "Testing inference with {} different input tensor sizes", len(test_cases)
+ "Testing inference with {} different inputs (B, N): {}",
+ len(test_cases),
+ test_cases,
)
def generate_test_cases():
@@ -329,7 +719,7 @@ def get_ns(n: int):
if isinstance(a.size, v0_5.ParameterizedSize)
}
- for n, batch_size in sorted(test_cases):
+ for batch_size, n in sorted(test_cases):
input_target_sizes, expected_output_sizes = model.get_axis_sizes(
get_ns(n), batch_size=batch_size
)
@@ -343,12 +733,14 @@ def get_ns(n: int):
resized_test_inputs = Sample(
members={
- t.id: test_inputs.members[t.id].resize_to(
- {
- aid: s
- for (tid, aid), s in input_target_sizes.items()
- if tid == t.id
- },
+ t.id: (
+ test_inputs.members[t.id].resize_to(
+ {
+ aid: s
+ for (tid, aid), s in input_target_sizes.items()
+ if tid == t.id
+ },
+ )
)
for t in model.inputs
},
@@ -422,9 +814,12 @@ def get_ns(n: int):
),
)
)
+ if stop_early and error is not None:
+ break
except Exception as e:
- error = str(e)
- tb = traceback.format_tb(e.__traceback__)
+ if get_validation_context().raise_errors:
+ raise e
+
model.validation_summary.add_detail(
ValidationDetail(
name=f"Run {weight_format} inference for parametrized inputs",
@@ -433,9 +828,9 @@ def get_ns(n: int):
errors=[
ErrorEntry(
loc=("weights", weight_format),
- msg=error,
+ msg=str(e),
type="bioimageio.core",
- traceback=tb,
+ with_traceback=True,
)
],
)
@@ -458,7 +853,7 @@ def _test_expected_resource_type(
ErrorEntry(
loc=("type",),
type="type",
- msg=f"expected type {expected_type}, found {rd.type}",
+ msg=f"Expected type {expected_type}, found {rd.type}",
)
]
),
diff --git a/bioimageio/core/axis.py b/bioimageio/core/axis.py
index 34dfa3e1..0b39045e 100644
--- a/bioimageio/core/axis.py
+++ b/bioimageio/core/axis.py
@@ -8,25 +8,33 @@
from bioimageio.spec.model import v0_5
-def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]):
- if a == "b":
+def _guess_axis_type(a: str):
+ if a in ("b", "batch"):
return "batch"
- elif a == "t":
+ elif a in ("t", "time"):
return "time"
- elif a == "i":
+ elif a in ("i", "index"):
return "index"
- elif a == "c":
+ elif a in ("c", "channel"):
return "channel"
elif a in ("x", "y", "z"):
return "space"
else:
- return "index" # return most unspecific axis
+ raise ValueError(
+ f"Failed to infer axis type for axis id '{a}'."
+ + " Consider using one of: '"
+ + "', '".join(
+ ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"]
+ )
+ + "'. Or creating an `Axis` object instead."
+ )
S = TypeVar("S", bound=str)
AxisId = v0_5.AxisId
+"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""
T = TypeVar("T")
PerAxis = Mapping[AxisId, T]
@@ -42,16 +50,22 @@ class Axis:
id: AxisId
type: Literal["batch", "channel", "index", "space", "time"]
+ def __post_init__(self):
+ if self.type == "batch":
+ self.id = AxisId("batch")
+ elif self.type == "channel":
+ self.id = AxisId("channel")
+
@classmethod
def create(cls, axis: AxisLike) -> Axis:
if isinstance(axis, cls):
return axis
elif isinstance(axis, Axis):
return Axis(id=axis.id, type=axis.type)
- elif isinstance(axis, str):
- return Axis(id=AxisId(axis), type=_get_axis_type(axis))
elif isinstance(axis, v0_5.AxisBase):
return Axis(id=AxisId(axis.id), type=axis.type)
+ elif isinstance(axis, str):
+ return Axis(id=AxisId(axis), type=_guess_axis_type(axis))
else:
assert_never(axis)
diff --git a/bioimageio/core/backends/__init__.py b/bioimageio/core/backends/__init__.py
new file mode 100644
index 00000000..c39b58b5
--- /dev/null
+++ b/bioimageio/core/backends/__init__.py
@@ -0,0 +1,3 @@
+from ._model_adapter import create_model_adapter
+
+__all__ = ["create_model_adapter"]
diff --git a/bioimageio/core/backends/_model_adapter.py b/bioimageio/core/backends/_model_adapter.py
new file mode 100644
index 00000000..db4d44e9
--- /dev/null
+++ b/bioimageio/core/backends/_model_adapter.py
@@ -0,0 +1,245 @@
+import sys
+import warnings
+from abc import ABC, abstractmethod
+from typing import (
+ Any,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ final,
+)
+
+from numpy.typing import NDArray
+from typing_extensions import assert_never
+
+from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
+
+from ..common import SupportedWeightsFormat
+from ..digest_spec import get_axes_infos, get_member_ids
+from ..sample import Sample, SampleBlock, SampleBlockWithOrigin
+from ..tensor import Tensor
+
+# Known weight formats in order of priority
+# First match wins
+DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
+ "pytorch_state_dict",
+ "tensorflow_saved_model_bundle",
+ "torchscript",
+ "onnx",
+ "keras_hdf5",
+)
+
+
+class ModelAdapter(ABC):
+ """
+ Represents model *without* any preprocessing or postprocessing.
+
+ ```
+ from bioimageio.core import load_description
+
+ model = load_description(...)
+
+ # option 1:
+ adapter = ModelAdapter.create(model)
+ adapter.forward(...)
+ adapter.unload()
+
+ # option 2:
+ with ModelAdapter.create(model) as adapter:
+ adapter.forward(...)
+ ```
+ """
+
+ def __init__(self, model_description: AnyModelDescr):
+ super().__init__()
+ self._model_descr = model_description
+ self._input_ids = get_member_ids(model_description.inputs)
+ self._output_ids = get_member_ids(model_description.outputs)
+ self._input_axes = [
+ tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
+ ]
+ self._output_axes = [
+ tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
+ ]
+ if isinstance(model_description, v0_4.ModelDescr):
+ self._input_is_optional = [False] * len(model_description.inputs)
+ else:
+ self._input_is_optional = [ipt.optional for ipt in model_description.inputs]
+
+ @final
+ @classmethod
+ def create(
+ cls,
+ model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
+ *,
+ devices: Optional[Sequence[str]] = None,
+ weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
+ ):
+ """
+ Creates model adapter based on the passed spec
+ Note: All specific adapters should happen inside this function to prevent different framework
+ initializations interfering with each other
+ """
+ if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
+ raise TypeError(
+ f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
+ )
+
+ weights = model_description.weights
+ errors: List[Exception] = []
+ weight_format_priority_order = (
+ DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
+ if weight_format_priority_order is None
+ else weight_format_priority_order
+ )
+ # limit weight formats to the ones present
+ weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
+ w for w in weight_format_priority_order if getattr(weights, w) is not None
+ ]
+ if not weight_format_priority_order_present:
+ raise ValueError(
+ f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
+ )
+
+ for wf in weight_format_priority_order_present:
+ if wf == "pytorch_state_dict":
+ assert weights.pytorch_state_dict is not None
+ try:
+ from .pytorch_backend import PytorchModelAdapter
+
+ return PytorchModelAdapter(
+ model_description=model_description, devices=devices
+ )
+ except Exception as e:
+ errors.append(e)
+ elif wf == "tensorflow_saved_model_bundle":
+ assert weights.tensorflow_saved_model_bundle is not None
+ try:
+ from .tensorflow_backend import create_tf_model_adapter
+
+ return create_tf_model_adapter(
+ model_description=model_description, devices=devices
+ )
+ except Exception as e:
+ errors.append(e)
+ elif wf == "onnx":
+ assert weights.onnx is not None
+ try:
+ from .onnx_backend import ONNXModelAdapter
+
+ return ONNXModelAdapter(
+ model_description=model_description, devices=devices
+ )
+ except Exception as e:
+ errors.append(e)
+ elif wf == "torchscript":
+ assert weights.torchscript is not None
+ try:
+ from .torchscript_backend import TorchscriptModelAdapter
+
+ return TorchscriptModelAdapter(
+ model_description=model_description, devices=devices
+ )
+ except Exception as e:
+ errors.append(e)
+ elif wf == "keras_hdf5":
+ assert weights.keras_hdf5 is not None
+ # keras can either be installed as a separate package or used as part of tensorflow
+ # we try to first import the keras model adapter using the separate package and,
+ # if it is not available, try to load the one using tf
+ try:
+ try:
+ from .keras_backend import KerasModelAdapter
+ except Exception:
+ from .tensorflow_backend import KerasModelAdapter
+
+ return KerasModelAdapter(
+ model_description=model_description, devices=devices
+ )
+ except Exception as e:
+ errors.append(e)
+ else:
+ assert_never(wf)
+
+ assert errors
+ if len(weight_format_priority_order) == 1:
+ assert len(errors) == 1
+ raise errors[0]
+
+ else:
+ msg = (
+ "None of the weight format specific model adapters could be created"
+ + " in this environment."
+ )
+ if sys.version_info[:2] >= (3, 11):
+ raise ExceptionGroup(msg, errors)
+ else:
+ raise ValueError(msg) from Exception(errors)
+
+ @final
+ def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
+ warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
+
+ def forward(
+ self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
+ ) -> Sample:
+ """
+ Run forward pass of model to get model predictions
+
+ Note: sample id and sample stat attributes are passed through
+ """
+ unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
+ if unexpected:
+ warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")
+
+ input_arrays = [
+ (
+ None
+ if (a := input_sample.members.get(in_id)) is None
+ else a.transpose(in_order).data.data
+ )
+ for in_id, in_order in zip(self._input_ids, self._input_axes)
+ ]
+ output_arrays = self._forward_impl(input_arrays)
+ assert len(output_arrays) <= len(self._output_ids)
+ output_tensors = [
+ None if a is None else Tensor(a, dims=d)
+ for a, d in zip(output_arrays, self._output_axes)
+ ]
+ return Sample(
+ members={
+ tid: out
+ for tid, out in zip(
+ self._output_ids,
+ output_tensors,
+ )
+ if out is not None
+ },
+ stat=input_sample.stat,
+ id=(
+ input_sample.id
+ if isinstance(input_sample, Sample)
+ else input_sample.sample_id
+ ),
+ )
+
+ @abstractmethod
+ def _forward_impl(
+ self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
+ """framework specific forward implementation"""
+
+ @abstractmethod
+ def unload(self):
+ """
+ Unload model from any devices, freeing their memory.
+ The model adapter should be considered unusable afterwards.
+ """
+
+ def _get_input_args_numpy(self, input_sample: Sample):
+ """helper to extract tensor args as transposed numpy arrays"""
+
+
+create_model_adapter = ModelAdapter.create
diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/backends/keras_backend.py
similarity index 60%
rename from bioimageio/core/model_adapters/_keras_model_adapter.py
rename to bioimageio/core/backends/keras_backend.py
index e6864ccc..1c10da7d 100644
--- a/bioimageio/core/model_adapters/_keras_model_adapter.py
+++ b/bioimageio/core/backends/keras_backend.py
@@ -1,39 +1,33 @@
import os
-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Union
from loguru import logger
from numpy.typing import NDArray
-from bioimageio.spec._internal.io_utils import download
+from bioimageio.spec._internal.io import download
+from bioimageio.spec._internal.type_guards import is_list, is_tuple
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import Version
from .._settings import settings
from ..digest_spec import get_axes_infos
-from ..tensor import Tensor
from ._model_adapter import ModelAdapter
os.environ["KERAS_BACKEND"] = settings.keras_backend
# by default, we use the keras integrated with tensorflow
+# TODO: check if we should prefer keras
try:
- import tensorflow as tf # pyright: ignore[reportMissingImports]
- from tensorflow import ( # pyright: ignore[reportMissingImports]
- keras, # pyright: ignore[reportUnknownVariableType]
+ import tensorflow as tf # pyright: ignore[reportMissingTypeStubs]
+ from tensorflow import ( # pyright: ignore[reportMissingTypeStubs]
+ keras, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
)
- tf_version = Version(tf.__version__) # pyright: ignore[reportUnknownArgumentType]
+ tf_version = Version(tf.__version__)
except Exception:
- try:
- import keras # pyright: ignore[reportMissingImports]
- except Exception as e:
- keras = None
- keras_error = str(e)
- else:
- keras_error = None
+ import keras # pyright: ignore[reportMissingTypeStubs]
+
tf_version = None
-else:
- keras_error = None
class KerasModelAdapter(ModelAdapter):
@@ -43,10 +37,7 @@ def __init__(
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
devices: Optional[Sequence[str]] = None,
) -> None:
- if keras is None:
- raise ImportError(f"failed to import keras: {keras_error}")
-
- super().__init__()
+ super().__init__(model_description=model_description)
if model_description.weights.keras_hdf5 is None:
raise ValueError("model has not keras_hdf5 weights specified")
model_tf_version = model_description.weights.keras_hdf5.tensorflow_version
@@ -84,22 +75,14 @@ def __init__(
for out in model_description.outputs
]
- def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
- _result: Union[Sequence[NDArray[Any]], NDArray[Any]]
- _result = self._network.predict( # pyright: ignore[reportUnknownVariableType]
- *[None if t is None else t.data.data for t in input_tensors]
- )
- if isinstance(_result, (tuple, list)):
- result: Sequence[NDArray[Any]] = _result
+ def _forward_impl( # pyright: ignore[reportUnknownParameterType]
+ self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ):
+ network_output = self._network.predict(*input_arrays) # type: ignore
+ if is_list(network_output) or is_tuple(network_output):
+ return network_output
else:
- result = [_result] # type: ignore
-
- assert len(result) == len(self._output_axes)
- ret: List[Optional[Tensor]] = []
- ret.extend(
- [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)]
- )
- return ret
+ return [network_output] # pyright: ignore[reportUnknownVariableType]
def unload(self) -> None:
logger.warning(
diff --git a/bioimageio/core/backends/onnx_backend.py b/bioimageio/core/backends/onnx_backend.py
new file mode 100644
index 00000000..d5b89152
--- /dev/null
+++ b/bioimageio/core/backends/onnx_backend.py
@@ -0,0 +1,53 @@
+# pyright: reportUnknownVariableType=false
+import warnings
+from typing import Any, List, Optional, Sequence, Union
+
+import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
+from numpy.typing import NDArray
+
+from bioimageio.spec._internal.type_guards import is_list, is_tuple
+from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.utils import download
+
+from ..model_adapters import ModelAdapter
+
+
+class ONNXModelAdapter(ModelAdapter):
+ def __init__(
+ self,
+ *,
+ model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
+ devices: Optional[Sequence[str]] = None,
+ ):
+ super().__init__(model_description=model_description)
+
+ if model_description.weights.onnx is None:
+ raise ValueError(f"No ONNX weights specified for {model_description.name}")
+
+ local_path = download(model_description.weights.onnx.source).path
+ self._session = rt.InferenceSession(local_path.read_bytes())
+ onnx_inputs = self._session.get_inputs()
+ self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]
+
+ if devices is not None:
+ warnings.warn(
+ f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
+ )
+
+ def _forward_impl(
+ self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ) -> List[Optional[NDArray[Any]]]:
+ result: Any = self._session.run(
+ None, dict(zip(self._input_names, input_arrays))
+ )
+ if is_list(result) or is_tuple(result):
+ result_seq = list(result)
+ else:
+ result_seq = [result]
+
+ return result_seq
+
+ def unload(self) -> None:
+ warnings.warn(
+ "Device management is not implemented for onnx yet, cannot unload model"
+ )
diff --git a/bioimageio/core/backends/pytorch_backend.py b/bioimageio/core/backends/pytorch_backend.py
new file mode 100644
index 00000000..af1ea85d
--- /dev/null
+++ b/bioimageio/core/backends/pytorch_backend.py
@@ -0,0 +1,180 @@
+import gc
+import warnings
+from contextlib import nullcontext
+from io import TextIOWrapper
+from pathlib import Path
+from typing import Any, List, Literal, Optional, Sequence, Union
+
+import torch
+from loguru import logger
+from numpy.typing import NDArray
+from torch import nn
+from typing_extensions import assert_never
+
+from bioimageio.spec._internal.type_guards import is_list, is_ndarray, is_tuple
+from bioimageio.spec.common import ZipPath
+from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
+from bioimageio.spec.utils import download
+
+from ..digest_spec import import_callable
+from ._model_adapter import ModelAdapter
+
+
+class PytorchModelAdapter(ModelAdapter):
+ def __init__(
+ self,
+ *,
+ model_description: AnyModelDescr,
+ devices: Optional[Sequence[Union[str, torch.device]]] = None,
+ mode: Literal["eval", "train"] = "eval",
+ ):
+ super().__init__(model_description=model_description)
+ weights = model_description.weights.pytorch_state_dict
+ if weights is None:
+ raise ValueError("No `pytorch_state_dict` weights found")
+
+ devices = get_devices(devices)
+ self._model = load_torch_model(weights, load_state=True, devices=devices)
+ if mode == "eval":
+ self._model = self._model.eval()
+ elif mode == "train":
+ self._model = self._model.train()
+ else:
+ assert_never(mode)
+
+ self._mode: Literal["eval", "train"] = mode
+ self._primary_device = devices[0]
+
+ def _forward_impl(
+ self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ) -> List[Optional[NDArray[Any]]]:
+ tensors = [
+ None if a is None else torch.from_numpy(a).to(self._primary_device)
+ for a in input_arrays
+ ]
+
+ if self._mode == "eval":
+ ctxt = torch.no_grad
+ elif self._mode == "train":
+ ctxt = nullcontext
+ else:
+ assert_never(self._mode)
+
+ with ctxt():
+ model_out = self._model(*tensors)
+
+ if is_tuple(model_out) or is_list(model_out):
+ model_out_seq = model_out
+ else:
+ model_out_seq = [model_out]
+
+ result: List[Optional[NDArray[Any]]] = []
+ for i, r in enumerate(model_out_seq):
+ if r is None:
+ result.append(None)
+ elif isinstance(r, torch.Tensor):
+ r_np: NDArray[Any] = r.detach().cpu().numpy()
+ result.append(r_np)
+ elif is_ndarray(r):
+ result.append(r)
+ else:
+ raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")
+
+ return result
+
+ def unload(self) -> None:
+ del self._model
+ _ = gc.collect() # deallocate memory
+ assert torch is not None
+ torch.cuda.empty_cache() # release reserved memory
+
+
+def load_torch_model(
+ weight_spec: Union[
+ v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
+ ],
+ *,
+ load_state: bool = True,
+ devices: Optional[Sequence[Union[str, torch.device]]] = None,
+) -> nn.Module:
+ custom_callable = import_callable(
+ weight_spec.architecture,
+ sha256=(
+ weight_spec.architecture_sha256
+ if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
+ else weight_spec.sha256
+ ),
+ )
+ model_kwargs = (
+ weight_spec.kwargs
+ if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
+ else weight_spec.architecture.kwargs
+ )
+ torch_model = custom_callable(**model_kwargs)
+
+ if not isinstance(torch_model, nn.Module):
+ if isinstance(
+ weight_spec.architecture,
+ (v0_4.CallableFromFile, v0_4.CallableFromDepencency),
+ ):
+ callable_name = weight_spec.architecture.callable_name
+ else:
+ callable_name = weight_spec.architecture.callable
+
+ raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")
+
+ if load_state or devices:
+ use_devices = get_devices(devices)
+ torch_model = torch_model.to(use_devices[0])
+ if load_state:
+ torch_model = load_torch_state_dict(
+ torch_model,
+ path=download(weight_spec).path,
+ devices=use_devices,
+ )
+ return torch_model
+
+
+def load_torch_state_dict(
+ model: nn.Module,
+ path: Union[Path, ZipPath],
+ devices: Sequence[torch.device],
+) -> nn.Module:
+ model = model.to(devices[0])
+ with path.open("rb") as f:
+ assert not isinstance(f, TextIOWrapper)
+ state = torch.load(f, map_location=devices[0], weights_only=True)
+
+ incompatible = model.load_state_dict(state)
+ if (
+ incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
+ and incompatible.missing_keys
+ ):
+ logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
+
+ if (
+ incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
+ and incompatible.unexpected_keys
+ ):
+ logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)
+
+ return model
+
+
+def get_devices(
+ devices: Optional[Sequence[Union[torch.device, str]]] = None,
+) -> List[torch.device]:
+ if not devices:
+ torch_devices = [
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ ]
+ else:
+ torch_devices = [torch.device(d) for d in devices]
+
+ if len(torch_devices) > 1:
+ warnings.warn(
+ f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}"
+ )
+ torch_devices = torch_devices[:1]
+
+ return torch_devices
diff --git a/bioimageio/core/backends/tensorflow_backend.py b/bioimageio/core/backends/tensorflow_backend.py
new file mode 100644
index 00000000..99efe9ef
--- /dev/null
+++ b/bioimageio/core/backends/tensorflow_backend.py
@@ -0,0 +1,212 @@
+from pathlib import Path
+from typing import Any, Optional, Sequence, Union
+
+import numpy as np
+import tensorflow as tf # pyright: ignore[reportMissingTypeStubs]
+from loguru import logger
+from numpy.typing import NDArray
+
+from bioimageio.core.io import ensure_unzipped
+from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
+
+from ._model_adapter import ModelAdapter
+
+
+class TensorflowModelAdapter(ModelAdapter):
+ weight_format = "tensorflow_saved_model_bundle"
+
+ def __init__(
+ self,
+ *,
+ model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
+ devices: Optional[Sequence[str]] = None,
+ ):
+ super().__init__(model_description=model_description)
+
+ weight_file = model_description.weights.tensorflow_saved_model_bundle
+ if model_description.weights.tensorflow_saved_model_bundle is None:
+ raise ValueError("No `tensorflow_saved_model_bundle` weights found")
+
+ if devices is not None:
+ logger.warning(
+ f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
+ )
+
+ # TODO: check how to load tf weights without unzipping
+ weight_file = ensure_unzipped(
+ model_description.weights.tensorflow_saved_model_bundle.source,
+ Path("bioimageio_unzipped_tf_weights"),
+ )
+ self._network = str(weight_file)
+
+ # TODO currently we reload the model every time. It would be better to keep the graph and session
+ # alive in between of forward passes (but then the sessions need to be properly opened / closed)
+ def _forward_impl( # pyright: ignore[reportUnknownParameterType]
+ self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ):
+ # TODO read from spec
+ tag = ( # pyright: ignore[reportUnknownVariableType]
+ tf.saved_model.tag_constants.SERVING # pyright: ignore[reportAttributeAccessIssue]
+ )
+ signature_key = ( # pyright: ignore[reportUnknownVariableType]
+ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pyright: ignore[reportAttributeAccessIssue]
+ )
+
+ graph = tf.Graph()
+ with graph.as_default():
+ with tf.Session( # pyright: ignore[reportAttributeAccessIssue]
+ graph=graph
+ ) as sess: # pyright: ignore[reportUnknownVariableType]
+ # load the model and the signature
+ graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
+ sess, [tag], self._network
+ )
+ signature = ( # pyright: ignore[reportUnknownVariableType]
+ graph_def.signature_def
+ )
+
+ # get the tensors into the graph
+ in_names = [ # pyright: ignore[reportUnknownVariableType]
+ signature[signature_key].inputs[key].name for key in self._input_ids
+ ]
+ out_names = [ # pyright: ignore[reportUnknownVariableType]
+ signature[signature_key].outputs[key].name
+ for key in self._output_ids
+ ]
+ in_tf_tensors = [
+ graph.get_tensor_by_name(
+ name # pyright: ignore[reportUnknownArgumentType]
+ )
+ for name in in_names # pyright: ignore[reportUnknownVariableType]
+ ]
+ out_tf_tensors = [
+ graph.get_tensor_by_name(
+ name # pyright: ignore[reportUnknownArgumentType]
+ )
+ for name in out_names # pyright: ignore[reportUnknownVariableType]
+ ]
+
+ # run prediction
+ res = sess.run( # pyright: ignore[reportUnknownVariableType]
+ dict(
+ zip(
+ out_names, # pyright: ignore[reportUnknownArgumentType]
+ out_tf_tensors,
+ )
+ ),
+ dict(zip(in_tf_tensors, input_arrays)),
+ )
+ # from dict to list of tensors
+ res = [ # pyright: ignore[reportUnknownVariableType]
+ res[out]
+ for out in out_names # pyright: ignore[reportUnknownVariableType]
+ ]
+
+ return res # pyright: ignore[reportUnknownVariableType]
+
+ def unload(self) -> None:
+ logger.warning(
+ "Device management is not implemented for tensorflow 1, cannot unload model"
+ )
+
+
+class KerasModelAdapter(ModelAdapter):
+ def __init__(
+ self,
+ *,
+ model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
+ devices: Optional[Sequence[str]] = None,
+ ):
+ if model_description.weights.tensorflow_saved_model_bundle is None:
+ raise ValueError("No `tensorflow_saved_model_bundle` weights found")
+
+ super().__init__(model_description=model_description)
+ if devices is not None:
+ logger.warning(
+ f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
+ )
+
+ # TODO: check how to load tf weights without unzipping
+ weight_file = ensure_unzipped(
+ model_description.weights.tensorflow_saved_model_bundle.source,
+ Path("bioimageio_unzipped_tf_weights"),
+ )
+
+ try:
+ self._network = tf.keras.layers.TFSMLayer( # pyright: ignore[reportAttributeAccessIssue]
+ weight_file,
+ call_endpoint="serve",
+ )
+ except Exception as e:
+ try:
+ self._network = tf.keras.layers.TFSMLayer( # pyright: ignore[reportAttributeAccessIssue]
+ weight_file, call_endpoint="serving_default"
+ )
+ except Exception as ee:
+ logger.opt(exception=ee).info(
+ "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
+ )
+ raise e
+
+ def _forward_impl( # pyright: ignore[reportUnknownParameterType]
+ self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ):
+ assert tf is not None
+ tf_tensor = [
+ None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
+ ]
+
+ result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType]
+
+ assert isinstance(result, dict)
+
+ # TODO: Use RDF's `outputs[i].id` here
+ result = list( # pyright: ignore[reportUnknownVariableType]
+ result.values() # pyright: ignore[reportUnknownArgumentType]
+ )
+
+ return [ # pyright: ignore[reportUnknownVariableType]
+ (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
+ for r in result # pyright: ignore[reportUnknownVariableType]
+ ]
+
+ def unload(self) -> None:
+ logger.warning(
+ "Device management is not implemented for tensorflow>=2 models"
+ + f" using `{self.__class__.__name__}`, cannot unload model"
+ )
+
+
+def create_tf_model_adapter(
+ model_description: AnyModelDescr, devices: Optional[Sequence[str]]
+):
+ tf_version = v0_5.Version(tf.__version__)
+ weights = model_description.weights.tensorflow_saved_model_bundle
+ if weights is None:
+ raise ValueError("No `tensorflow_saved_model_bundle` weights found")
+
+ model_tf_version = weights.tensorflow_version
+ if model_tf_version is None:
+ logger.warning(
+ "The model does not specify the tensorflow version."
+ + f" Cannot check if it is compatible with installed tensorflow {tf_version}."
+ )
+ elif model_tf_version > tf_version:
+ logger.warning(
+ f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
+ )
+ elif (model_tf_version.major, model_tf_version.minor) != (
+ tf_version.major,
+ tf_version.minor,
+ ):
+ logger.warning(
+ "The tensorflow version specified by the model does not match the installed: "
+ + f"{model_tf_version} != {tf_version}."
+ )
+
+ if tf_version.major <= 1:
+ return TensorflowModelAdapter(
+ model_description=model_description, devices=devices
+ )
+ else:
+ return KerasModelAdapter(model_description=model_description, devices=devices)
diff --git a/bioimageio/core/backends/torchscript_backend.py b/bioimageio/core/backends/torchscript_backend.py
new file mode 100644
index 00000000..ce3ba131
--- /dev/null
+++ b/bioimageio/core/backends/torchscript_backend.py
@@ -0,0 +1,74 @@
+# pyright: reportUnknownVariableType=false
+import gc
+import warnings
+from typing import Any, List, Optional, Sequence, Union
+
+import torch
+from numpy.typing import NDArray
+
+from bioimageio.spec._internal.type_guards import is_list, is_tuple
+from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.utils import download
+
+from ..model_adapters import ModelAdapter
+
+
+class TorchscriptModelAdapter(ModelAdapter):
+ def __init__(
+ self,
+ *,
+ model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
+ devices: Optional[Sequence[str]] = None,
+ ):
+ super().__init__(model_description=model_description)
+ if model_description.weights.torchscript is None:
+ raise ValueError(
+ f"No torchscript weights found for model {model_description.name}"
+ )
+
+ weight_path = download(model_description.weights.torchscript.source).path
+ if devices is None:
+ self.devices = ["cuda" if torch.cuda.is_available() else "cpu"]
+ else:
+ self.devices = [torch.device(d) for d in devices]
+
+ if len(self.devices) > 1:
+ warnings.warn(
+ "Multiple devices for single torchscript model not yet implemented"
+ )
+
+ with weight_path.open("rb") as f:
+ self._model = torch.jit.load(f)
+
+ self._model.to(self.devices[0])
+ self._model = self._model.eval()
+
+ def _forward_impl(
+ self, input_arrays: Sequence[Optional[NDArray[Any]]]
+ ) -> List[Optional[NDArray[Any]]]:
+
+ with torch.no_grad():
+ torch_tensor = [
+ None if a is None else torch.from_numpy(a).to(self.devices[0])
+ for a in input_arrays
+ ]
+ output: Any = self._model.forward(*torch_tensor)
+ if is_list(output) or is_tuple(output):
+ output_seq: Sequence[Any] = output
+ else:
+ output_seq = [output]
+
+ return [
+ (
+ None
+ if r is None
+ else r.cpu().numpy() if isinstance(r, torch.Tensor) else r
+ )
+ for r in output_seq
+ ]
+
+ def unload(self) -> None:
+ self.devices = None
+ del self._model
+ _ = gc.collect() # deallocate memory
+ torch.cuda.empty_cache() # release reserved memory
diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py
index f7740092..4e40c1cf 100644
--- a/bioimageio/core/block_meta.py
+++ b/bioimageio/core/block_meta.py
@@ -258,6 +258,7 @@ def split_shape_into_blocks(
set(block_shape),
)
if any(shape[a] < block_shape[a] for a in block_shape):
+ # TODO: allow larger blockshape
raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")
assert all(a in shape for a in halo), (tuple(shape), set(halo))
diff --git a/bioimageio/core/cli.py b/bioimageio/core/cli.py
index fad44ab3..8e62239d 100644
--- a/bioimageio/core/cli.py
+++ b/bioimageio/core/cli.py
@@ -8,9 +8,11 @@
import shutil
import subprocess
import sys
+from abc import ABC
from argparse import RawTextHelpFormatter
from difflib import SequenceMatcher
from functools import cached_property
+from io import StringIO
from pathlib import Path
from pprint import pformat, pprint
from typing import (
@@ -18,6 +20,7 @@
Dict,
Iterable,
List,
+ Literal,
Mapping,
Optional,
Sequence,
@@ -27,8 +30,9 @@
Union,
)
+import rich.markdown
from loguru import logger
-from pydantic import BaseModel, Field, model_validator
+from pydantic import AliasChoices, BaseModel, Field, model_validator
from pydantic_settings import (
BaseSettings,
CliPositionalArg,
@@ -39,26 +43,30 @@
SettingsConfigDict,
YamlConfigSettingsSource,
)
-from ruyaml import YAML
from tqdm import tqdm
from typing_extensions import assert_never
-from bioimageio.spec import AnyModelDescr, InvalidDescr, load_description
+from bioimageio.spec import (
+ AnyModelDescr,
+ InvalidDescr,
+ ResourceDescr,
+ load_description,
+ save_bioimageio_yaml_only,
+ settings,
+ update_format,
+ update_hashes,
+)
+from bioimageio.spec._internal.io import is_yaml_value
from bioimageio.spec._internal.io_basics import ZipPath
+from bioimageio.spec._internal.io_utils import open_bioimageio_yaml
from bioimageio.spec._internal.types import NotEmpty
from bioimageio.spec.dataset import DatasetDescr
from bioimageio.spec.model import ModelDescr, v0_4, v0_5
from bioimageio.spec.notebook import NotebookDescr
-from bioimageio.spec.utils import download, ensure_description_is_model
-
-from .commands import (
- WeightFormatArgAll,
- WeightFormatArgAny,
- package,
- test,
- validate_format,
-)
-from .common import MemberId, SampleId
+from bioimageio.spec.utils import download, ensure_description_is_model, write_yaml
+
+from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test
+from .common import MemberId, SampleId, SupportedWeightsFormat
from .digest_spec import get_member_ids, load_sample_for_model
from .io import load_dataset_stat, save_dataset_stat, save_sample
from .prediction import create_prediction_pipeline
@@ -71,9 +79,15 @@
)
from .sample import Sample
from .stat_measures import Stat
-from .utils import VERSION
-
-yaml = YAML(typ="safe")
+from .utils import VERSION, compare
+from .weight_converters._add_weights import add_weights
+
+WEIGHT_FORMAT_ALIASES = AliasChoices(
+ "weight-format",
+ "weights-format",
+ "weight_format",
+ "weights_format",
+)
class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
@@ -84,9 +98,31 @@ class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True
pass
+class WithSummaryLogging(ArgMixin):
+ summary: Union[
+ Literal["display"], Path, Sequence[Union[Literal["display"], Path]]
+ ] = Field(
+ "display",
+ examples=[
+ "display",
+ Path("summary.md"),
+ Path("bioimageio_summaries/"),
+ ["display", Path("summary.md")],
+ ],
+ )
+ """Display the validation summary or save it as JSON, Markdown or HTML.
+ The format is chosen based on the suffix: `.json`, `.md`, `.html`.
+ If a folder is given (path w/o suffix) the summary is saved in all formats.
+ Choose/add `"display"` to render the validation summary to the terminal.
+ """
+
+ def log(self, descr: Union[ResourceDescr, InvalidDescr]):
+ _ = descr.validation_summary.log(self.summary)
+
+
class WithSource(ArgMixin):
source: CliPositionalArg[str]
- """Url/path to a bioimageio.yaml/rdf.yaml file
+ """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
or a bioimage.io resource identifier, e.g. 'affable-shark'"""
@cached_property
@@ -100,29 +136,49 @@ def descr_id(self) -> str:
"""
if isinstance(self.descr, InvalidDescr):
return str(getattr(self.descr, "id", getattr(self.descr, "name")))
- else:
- return str(
- (
- (bio_config := self.descr.config.get("bioimageio", {}))
- and isinstance(bio_config, dict)
- and bio_config.get("nickname")
- )
- or self.descr.id
- or self.descr.name
- )
+
+ nickname = None
+ if (
+ isinstance(self.descr.config, v0_5.Config)
+ and (bio_config := self.descr.config.bioimageio)
+ and bio_config.model_extra is not None
+ ):
+ nickname = bio_config.model_extra.get("nickname")
+
+ return str(nickname or self.descr.id or self.descr.name)
-class ValidateFormatCmd(CmdBase, WithSource):
- """validate the meta data format of a bioimageio resource."""
+class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
+ """Validate the meta data format of a bioimageio resource."""
+
+ perform_io_checks: bool = Field(
+ settings.perform_io_checks, alias="perform-io-checks"
+ )
+ """Whether or not to perform validations that require downloading remote files.
+ Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
+ """
+
+ @cached_property
+ def descr(self):
+ return load_description(self.source, perform_io_checks=self.perform_io_checks)
def run(self):
- sys.exit(validate_format(self.descr))
+ self.log(self.descr)
+ sys.exit(
+ 0
+ if self.descr.validation_summary.status in ("valid-format", "passed")
+ else 1
+ )
-class TestCmd(CmdBase, WithSource):
- """Test a bioimageio resource (beyond meta data formatting)"""
+class TestCmd(CmdBase, WithSource, WithSummaryLogging):
+ """Test a bioimageio resource (beyond meta data formatting)."""
- weight_format: WeightFormatArgAll = "all"
+ weight_format: WeightFormatArgAll = Field(
+ "all",
+ alias="weight-format",
+ validation_alias=WEIGHT_FORMAT_ALIASES,
+ )
"""The weight format to limit testing to.
(only relevant for model resources)"""
@@ -130,8 +186,24 @@ class TestCmd(CmdBase, WithSource):
devices: Optional[Union[str, Sequence[str]]] = None
"""Device(s) to use for testing"""
- decimal: int = 4
- """Precision for numerical comparisons"""
+ runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
+ "currently-active", alias="runtime-env"
+ )
+ """The python environment to run the tests in
+ - `"currently-active"`: use active Python interpreter
+ - `"as-described"`: generate a conda environment YAML file based on the model
+ weights description.
+ - A path to a conda environment YAML.
+ Note: The `bioimageio.core` dependency will be added automatically if not present.
+ """
+
+ determinism: Literal["seed_only", "full"] = "seed_only"
+ """Modes to improve reproducibility of test outputs."""
+
+ stop_early: bool = Field(
+ False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
+ )
+ """Do not run further subtests after a failed one."""
def run(self):
sys.exit(
@@ -139,26 +211,32 @@ def run(self):
self.descr,
weight_format=self.weight_format,
devices=self.devices,
- decimal=self.decimal,
+ summary=self.summary,
+ runtime_env=self.runtime_env,
+ determinism=self.determinism,
)
)
-class PackageCmd(CmdBase, WithSource):
- """save a resource's metadata with its associated files."""
+class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
+ """Save a resource's metadata with its associated files."""
path: CliPositionalArg[Path]
"""The path to write the (zipped) package to.
If it does not have a `.zip` suffix
this command will save the package as an unzipped folder instead."""
- weight_format: WeightFormatArgAll = "all"
+ weight_format: WeightFormatArgAll = Field(
+ "all",
+ alias="weight-format",
+ validation_alias=WEIGHT_FORMAT_ALIASES,
+ )
"""The weight format to include in the package (for model descriptions only)."""
def run(self):
if isinstance(self.descr, InvalidDescr):
- self.descr.validation_summary.display()
- raise ValueError("resource description is invalid")
+ self.log(self.descr)
+ raise ValueError(f"Invalid {self.descr.type} description.")
sys.exit(
package(
@@ -182,7 +260,7 @@ def _get_stat(
req_dataset_meas, _ = get_required_dataset_measures(model_descr)
if stats_path.exists():
- logger.info(f"loading precomputed dataset measures from {stats_path}")
+ logger.info("loading precomputed dataset measures from {}", stats_path)
stat = load_dataset_stat(stats_path)
for m in req_dataset_meas:
if m not in stat:
@@ -203,6 +281,110 @@ def _get_stat(
return stat
+class UpdateCmdBase(CmdBase, WithSource, ABC):
+ output: Union[Literal["display", "stdout"], Path] = "display"
+ """Output updated bioimageio.yaml to the terminal or write to a file.
+ Notes:
+ - `"display"`: Render to the terminal with syntax highlighting.
+ - `"stdout"`: Write to sys.stdout without syntax highlighting.
+ (More convenient for copying the updated bioimageio.yaml from the terminal.)
+ """
+
+ diff: Union[bool, Path] = Field(True, alias="diff")
+ """Output a diff of original and updated bioimageio.yaml.
+ If a given path has an `.html` extension, a standalone HTML file is written,
+ otherwise the diff is saved in unified diff format (pure text).
+ """
+
+ exclude_unset: bool = Field(True, alias="exclude-unset")
+ """Exclude fields that have not been explicitly set."""
+
+ exclude_defaults: bool = Field(False, alias="exclude-defaults")
+ """Exclude fields that have the default value (even if set explicitly)."""
+
+ @cached_property
+ def updated(self) -> Union[ResourceDescr, InvalidDescr]:
+ raise NotImplementedError
+
+ def run(self):
+ original_yaml = open_bioimageio_yaml(self.source).unparsed_content
+ assert isinstance(original_yaml, str)
+ stream = StringIO()
+
+ save_bioimageio_yaml_only(
+ self.updated,
+ stream,
+ exclude_unset=self.exclude_unset,
+ exclude_defaults=self.exclude_defaults,
+ )
+ updated_yaml = stream.getvalue()
+
+ diff = compare(
+ original_yaml.split("\n"),
+ updated_yaml.split("\n"),
+ diff_format=(
+ "html"
+ if isinstance(self.diff, Path) and self.diff.suffix == ".html"
+ else "unified"
+ ),
+ )
+
+ if isinstance(self.diff, Path):
+ _ = self.diff.write_text(diff, encoding="utf-8")
+ elif self.diff:
+ console = rich.console.Console()
+ diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
+ console.print(rich.markdown.Markdown(diff_md))
+
+ if isinstance(self.output, Path):
+ _ = self.output.write_text(updated_yaml, encoding="utf-8")
+ logger.info(f"written updated description to {self.output}")
+ elif self.output == "display":
+ updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
+ rich.console.Console().print(rich.markdown.Markdown(updated_md))
+ elif self.output == "stdout":
+ print(updated_yaml)
+ else:
+ assert_never(self.output)
+
+ if isinstance(self.updated, InvalidDescr):
+ logger.warning("Update resulted in invalid description")
+ _ = self.updated.validation_summary.display()
+
+
+class UpdateFormatCmd(UpdateCmdBase):
+ """Update the metadata format to the latest format version."""
+
+ exclude_defaults: bool = Field(True, alias="exclude-defaults")
+ """Exclude fields that have the default value (even if set explicitly).
+
+ Note:
+ The update process sets most unset fields explicitly with their default value.
+ """
+
+ perform_io_checks: bool = Field(
+ settings.perform_io_checks, alias="perform-io-checks"
+ )
+ """Whether or not to attempt validation that may require file download.
+ If `True` file hash values are added if not present."""
+
+ @cached_property
+ def updated(self):
+ return update_format(
+ self.source,
+ exclude_defaults=self.exclude_defaults,
+ perform_io_checks=self.perform_io_checks,
+ )
+
+
+class UpdateHashesCmd(UpdateCmdBase):
+ """Create a bioimageio.yaml description with updated file hashes."""
+
+ @cached_property
+ def updated(self):
+ return update_hashes(self.source)
+
+
class PredictCmd(CmdBase, WithSource):
"""Run inference on your data with a bioimage.io model."""
@@ -222,7 +404,7 @@ class PredictCmd(CmdBase, WithSource):
Example inputs to process sample 'a' and 'b'
for a model expecting a 'raw' and a 'mask' input tensor:
- --inputs="[[\"a_raw.tif\",\"a_mask.tif\"],[\"b_raw.tif\",\"b_mask.tif\"]]"
+ --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
(Note that JSON double quotes need to be escaped.)
Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
@@ -270,7 +452,11 @@ class PredictCmd(CmdBase, WithSource):
"""preview which files would be processed
and what outputs would be generated."""
- weight_format: WeightFormatArgAny = "any"
+ weight_format: WeightFormatArgAny = Field(
+ "any",
+ alias="weight-format",
+ validation_alias=WEIGHT_FORMAT_ALIASES,
+ )
"""The weight format to use."""
example: bool = False
@@ -318,13 +504,15 @@ def _example(self):
bioimageio_cli_path = example_path / YAML_FILE
stats_file = "dataset_statistics.json"
stats = (example_path / stats_file).as_posix()
- yaml.dump(
- dict(
- inputs=inputs,
- outputs=output_pattern,
- stats=stats_file,
- blockwise=self.blockwise,
- ),
+ cli_example_args = dict(
+ inputs=inputs,
+ outputs=output_pattern,
+ stats=stats_file,
+ blockwise=self.blockwise,
+ )
+ assert is_yaml_value(cli_example_args)
+ write_yaml(
+ cli_example_args,
bioimageio_cli_path,
)
@@ -545,16 +733,49 @@ def input_dataset(stat: Stat):
save_sample(sp_out, sample_out)
+class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
+ output: CliPositionalArg[Path]
+ """The path to write the updated model package to."""
+
+ source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
+ """Exclusively use these weights to convert to other formats."""
+
+ target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
+ """Exclusively add this weight format."""
+
+ verbose: bool = False
+ """Log more (error) output."""
+
+ def run(self):
+ model_descr = ensure_description_is_model(self.descr)
+ if isinstance(model_descr, v0_4.ModelDescr):
+ raise TypeError(
+ f"model format {model_descr.format_version} not supported."
+ + " Please update the model first."
+ )
+ updated_model_descr = add_weights(
+ model_descr,
+ output_path=self.output,
+ source_format=self.source_format,
+ target_format=self.target_format,
+ verbose=self.verbose,
+ )
+ if updated_model_descr is None:
+ return
+
+ self.log(updated_model_descr)
+
+
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"
class Bioimageio(
BaseSettings,
+ cli_implicit_flags=True,
cli_parse_args=True,
cli_prog_name="bioimageio",
cli_use_class_docs_for_groups=True,
- cli_implicit_flags=True,
use_attribute_docstrings=True,
):
"""bioimageio - CLI for bioimage.io resources 🦒"""
@@ -576,6 +797,16 @@ class Bioimageio(
predict: CliSubCommand[PredictCmd]
"Predict with a model resource"
+ update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
+ """Update the metadata format"""
+
+ update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
+ """Create a bioimageio.yaml description with updated file hashes."""
+
+ add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
+ """Add additional weights to the model descriptions converted from available
+ formats to improve deployability."""
+
@classmethod
def settings_customise_sources(
cls,
@@ -613,7 +844,15 @@ def run(self):
"executing CLI command:\n{}",
pformat({k: v for k, v in self.model_dump().items() if v is not None}),
)
- cmd = self.validate_format or self.test or self.package or self.predict
+ cmd = (
+ self.add_weights
+ or self.package
+ or self.predict
+ or self.test
+ or self.update_format
+ or self.update_hashes
+ or self.validate_format
+ )
assert cmd is not None
cmd.run()
diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py
index c71d495f..7184014c 100644
--- a/bioimageio/core/commands.py
+++ b/bioimageio/core/commands.py
@@ -1,4 +1,4 @@
-"""These functions implement the logic of the bioimageio command line interface
+"""These functions are used in the bioimageio command line interface
defined in `bioimageio.core.cli`."""
from pathlib import Path
@@ -6,18 +6,18 @@
from typing_extensions import Literal
+from bioimageio.core.common import SupportedWeightsFormat
from bioimageio.spec import (
InvalidDescr,
ResourceDescr,
save_bioimageio_package,
save_bioimageio_package_as_folder,
)
-from bioimageio.spec.model.v0_5 import WeightsFormat
from ._resource_tests import test_description
-WeightFormatArgAll = Literal[WeightsFormat, "all"]
-WeightFormatArgAny = Literal[WeightsFormat, "any"]
+WeightFormatArgAll = Literal[SupportedWeightsFormat, "all"]
+WeightFormatArgAny = Literal[SupportedWeightsFormat, "any"]
def test(
@@ -25,45 +25,53 @@ def test(
*,
weight_format: WeightFormatArgAll = "all",
devices: Optional[Union[str, Sequence[str]]] = None,
- decimal: int = 4,
+ summary: Union[
+ Literal["display"], Path, Sequence[Union[Literal["display"], Path]]
+ ] = "display",
+ runtime_env: Union[
+ Literal["currently-active", "as-described"], Path
+ ] = "currently-active",
+ determinism: Literal["seed_only", "full"] = "seed_only",
) -> int:
- """test a bioimageio resource
+ """Test a bioimageio resource.
- Args:
- source: Path or URL to the bioimageio resource description file
- (bioimageio.yaml or rdf.yaml) or to a zipped resource
- weight_format: (model only) The weight format to use
- devices: Device(s) to use for testing
- decimal: Precision for numerical comparisons
+ Arguments as described in `bioimageio.core.cli.TestCmd`
"""
if isinstance(descr, InvalidDescr):
- descr.validation_summary.display()
- return 1
+ test_summary = descr.validation_summary
+ else:
+ test_summary = test_description(
+ descr,
+ weight_format=None if weight_format == "all" else weight_format,
+ devices=[devices] if isinstance(devices, str) else devices,
+ runtime_env=runtime_env,
+ determinism=determinism,
+ )
- summary = test_description(
- descr,
- weight_format=None if weight_format == "all" else weight_format,
- devices=[devices] if isinstance(devices, str) else devices,
- decimal=decimal,
- )
- summary.display()
- return 0 if summary.status == "passed" else 1
+ _ = test_summary.log(summary)
+ return 0 if test_summary.status == "passed" else 1
def validate_format(
descr: Union[ResourceDescr, InvalidDescr],
+ summary: Union[Path, Sequence[Path]] = (),
):
- """validate the meta data format of a bioimageio resource
+ """DEPRECATED; Access the existing `validation_summary` attribute instead.
+ validate the meta data format of a bioimageio resource
Args:
descr: a bioimageio resource description
"""
- descr.validation_summary.display()
- return 0 if descr.validation_summary.status == "passed" else 1
+ _ = descr.validation_summary.save(summary)
+ return 0 if descr.validation_summary.status in ("valid-format", "passed") else 1
+# TODO: absorb into `save_bioimageio_package`
def package(
- descr: ResourceDescr, path: Path, *, weight_format: WeightFormatArgAll = "all"
+ descr: ResourceDescr,
+ path: Path,
+ *,
+ weight_format: WeightFormatArgAll = "all",
):
"""Save a resource's metadata with its associated files.
@@ -76,8 +84,12 @@ def package(
weight-format: include only this single weight-format (if not 'all').
"""
if isinstance(descr, InvalidDescr):
- descr.validation_summary.display()
- raise ValueError("resource description is invalid")
+ logged = descr.validation_summary.save()
+ msg = f"Invalid {descr.type} description."
+ if logged:
+ msg += f" Details saved to {logged}."
+
+ raise ValueError(msg)
if weight_format == "all":
weights_priority_order = None
diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py
index 78a85886..9f939061 100644
--- a/bioimageio/core/common.py
+++ b/bioimageio/core/common.py
@@ -15,6 +15,17 @@
from bioimageio.spec.model import v0_5
+from .axis import AxisId
+
+SupportedWeightsFormat = Literal[
+ "keras_hdf5",
+ "onnx",
+ "pytorch_state_dict",
+ "tensorflow_saved_model_bundle",
+ "torchscript",
+]
+
+
DTypeStr = Literal[
"bool",
"float32",
@@ -87,7 +98,19 @@ class SliceInfo(NamedTuple):
SampleId = Hashable
+"""ID of a sample, see `bioimageio.core.sample.Sample`"""
MemberId = v0_5.TensorId
+"""ID of a `Sample` member, see `bioimageio.core.sample.Sample`"""
+
+BlocksizeParameter = Union[
+ v0_5.ParameterizedSize_N,
+ Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
+]
+"""
+Parameter to determine a concrete size for parametrized axis sizes defined by
+`bioimageio.spec.model.v0_5.ParameterizedSize`.
+"""
+
T = TypeVar("T")
PerMember = Mapping[MemberId, T]
diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py
index edb5a45d..fb0462f5 100644
--- a/bioimageio/core/digest_spec.py
+++ b/bioimageio/core/digest_spec.py
@@ -1,6 +1,9 @@
from __future__ import annotations
+import collections.abc
+import hashlib
import importlib.util
+import sys
from itertools import chain
from pathlib import Path
from typing import (
@@ -23,9 +26,8 @@
from numpy.typing import NDArray
from typing_extensions import Unpack, assert_never
-from bioimageio.spec._internal.io import resolve_and_extract
-from bioimageio.spec._internal.io_utils import HashKwargs
-from bioimageio.spec.common import FileSource
+from bioimageio.spec._internal.io import HashKwargs
+from bioimageio.spec.common import FileDescr, FileSource, ZipPath
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
from bioimageio.spec.model.v0_5 import (
@@ -33,9 +35,10 @@
ArchitectureFromLibraryDescr,
ParameterizedSize_N,
)
-from bioimageio.spec.utils import load_array
+from bioimageio.spec.utils import download, load_array
-from .axis import AxisId, AxisInfo, AxisLike, PerAxis
+from ._settings import settings
+from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .block_meta import split_multiple_shapes_into_blocks
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
from .io import load_tensor
@@ -48,9 +51,16 @@
from .stat_measures import Stat
from .tensor import Tensor
+TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]
+
def import_callable(
- node: Union[CallableFromDepencency, ArchitectureFromLibraryDescr],
+ node: Union[
+ ArchitectureFromFileDescr,
+ ArchitectureFromLibraryDescr,
+ CallableFromDepencency,
+ CallableFromFile,
+ ],
/,
**kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
@@ -65,7 +75,6 @@ def import_callable(
c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
elif isinstance(node, ArchitectureFromFileDescr):
c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
-
else:
assert_never(node)
@@ -78,17 +87,70 @@ def import_callable(
def _import_from_file_impl(
source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
- local_file = resolve_and_extract(source, **kwargs)
- module_name = local_file.path.stem
- importlib_spec = importlib.util.spec_from_file_location(
- module_name, local_file.path
- )
- if importlib_spec is None:
- raise ImportError(f"Failed to import {module_name} from {source}.")
+ src_descr = FileDescr(source=source, **kwargs)
+ # ensure sha is valid even if perform_io_checks=False
+ src_descr.validate_sha256()
+ assert src_descr.sha256 is not None
+
+ local_source = src_descr.download()
+
+ source_bytes = local_source.path.read_bytes()
+ assert isinstance(source_bytes, bytes)
+ source_sha = hashlib.sha256(source_bytes).hexdigest()
+
+ # make sure we have unique module name
+ module_name = f"{local_source.path.stem}_{source_sha}"
+
+ # make sure we have a valid module name
+ if not module_name.isidentifier():
+ module_name = f"custom_module_{source_sha}"
+ assert module_name.isidentifier(), module_name
+
+ module = sys.modules.get(module_name)
+ if module is None:
+ try:
+ if isinstance(local_source.path, Path):
+ module_path = local_source.path
+ elif isinstance(local_source.path, ZipPath):
+                # save extracted source to cache:
+                # loading from a file on disk ensures we get readable tracebacks
+                # if any errors occur
+ module_path = (
+ settings.cache_path / f"{source_sha}-{local_source.path.name}"
+ )
+ _ = module_path.write_bytes(source_bytes)
+ else:
+ assert_never(local_source.path)
+
+ importlib_spec = importlib.util.spec_from_file_location(
+ module_name, module_path
+ )
- dep = importlib.util.module_from_spec(importlib_spec)
- importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"?
- return getattr(dep, callable_name)
+ if importlib_spec is None:
+ raise ImportError(f"Failed to import {source}")
+
+ module = importlib.util.module_from_spec(importlib_spec)
+ assert importlib_spec.loader is not None
+ importlib_spec.loader.exec_module(module)
+
+ except Exception as e:
+ raise ImportError(f"Failed to import {source}") from e
+ else:
+ sys.modules[module_name] = module # cache this module
+
+ try:
+ callable_attr = getattr(module, callable_name)
+ except AttributeError as e:
+ raise AttributeError(
+ f"Imported custom module from {source} has no `{callable_name}` attribute."
+ ) from e
+ except Exception as e:
+ raise AttributeError(
+ f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
+ ) from e
+
+ else:
+ return callable_attr
def get_axes_infos(
@@ -100,14 +162,15 @@ def get_axes_infos(
],
) -> List[AxisInfo]:
"""get a unified, simplified axis representation from spec axes"""
- return [
- (
- AxisInfo.create("i")
- if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x")
- else AxisInfo.create(a)
- )
- for a in io_descr.axes
- ]
+ ret: List[AxisInfo] = []
+ for a in io_descr.axes:
+ if isinstance(a, v0_5.AxisBase):
+ ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
+ else:
+ assert a in ("b", "i", "t", "c", "z", "y", "x")
+ ret.append(AxisInfo.create(a))
+
+ return ret
def get_member_id(
@@ -308,7 +371,7 @@ def get_io_sample_block_metas(
def get_tensor(
- src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
+ src: Union[ZipPath, TensorSource],
ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
"""helper to cast/load various tensor sources"""
@@ -322,7 +385,10 @@ def get_tensor(
if isinstance(src, np.ndarray):
return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
- if isinstance(src, Path):
+ if isinstance(src, FileDescr):
+ src = download(src).path
+
+ if isinstance(src, (ZipPath, Path, str)):
return load_tensor(src, axes=get_axes_infos(ipt))
assert_never(src)
@@ -333,10 +399,7 @@ def create_sample_for_model(
*,
stat: Optional[Stat] = None,
sample_id: SampleId = None,
- inputs: Optional[
- PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
- ] = None, # TODO: make non-optional
- **kwargs: NDArray[Any], # TODO: deprecate in favor of `inputs`
+ inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
"""Create a sample from a single set of input(s) for a specific bioimage.io model
@@ -345,9 +408,17 @@ def create_sample_for_model(
stat: dictionary with sample and dataset statistics (may be updated in-place!)
inputs: the input(s) constituting a single sample.
"""
- inputs = {MemberId(k): v for k, v in {**kwargs, **(inputs or {})}.items()}
model_inputs = {get_member_id(d): d for d in model.inputs}
+ if isinstance(inputs, collections.abc.Mapping):
+ inputs = {MemberId(k): v for k, v in inputs.items()}
+ elif len(model_inputs) == 1:
+ inputs = {list(model_inputs)[0]: inputs}
+ else:
+ raise TypeError(
+ f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
+ )
+
if unknown := {k for k in inputs if k not in model_inputs}:
raise ValueError(f"Got unexpected inputs: {unknown}")
diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py
index ee60a67a..dc5b70db 100644
--- a/bioimageio/core/io.py
+++ b/bioimageio/core/io.py
@@ -1,16 +1,35 @@
import collections.abc
import warnings
+import zipfile
+from io import TextIOWrapper
from pathlib import Path, PurePosixPath
-from typing import Any, Mapping, Optional, Sequence, Tuple, Union
+from shutil import copyfileobj
+from typing import (
+ Any,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+)
-import h5py
+import h5py # pyright: ignore[reportMissingTypeStubs]
import numpy as np
-from imageio.v3 import imread, imwrite
+from imageio.v3 import imread, imwrite # type: ignore
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, TypeAdapter
-
-from bioimageio.spec.utils import load_array, save_array
+from typing_extensions import assert_never
+
+from bioimageio.spec._internal.io import interprete_file_source
+from bioimageio.spec.common import (
+ HttpUrl,
+ PermissiveFileSource,
+ RelativeFilePath,
+ ZipPath,
+)
+from bioimageio.spec.utils import download, load_array, save_array
from .axis import AxisLike
from .common import PerMember
@@ -21,29 +40,54 @@
DEFAULT_H5_DATASET_PATH = "data"
-def load_image(path: Path, is_volume: Optional[bool] = None) -> NDArray[Any]:
+SUFFIXES_WITH_DATAPATH = (".h5", ".hdf", ".hdf5")
+
+
+def load_image(
+ source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
+) -> NDArray[Any]:
"""load a single image as numpy array
Args:
- path: image path
+ source: image source
is_volume: deprecated
"""
if is_volume is not None:
warnings.warn("**is_volume** is deprecated and will be removed soon.")
- file_path, subpath = _split_dataset_path(Path(path))
+ if isinstance(source, ZipPath):
+ parsed_source = source
+ else:
+ parsed_source = interprete_file_source(source)
- if file_path.suffix == ".npy":
+ if isinstance(parsed_source, RelativeFilePath):
+ src = parsed_source.absolute()
+ else:
+ src = parsed_source
+
+ # FIXME: why is pyright complaining about giving the union to _split_dataset_path?
+ if isinstance(src, Path):
+ file_source, subpath = _split_dataset_path(src)
+ elif isinstance(src, HttpUrl):
+ file_source, subpath = _split_dataset_path(src)
+ elif isinstance(src, ZipPath):
+ file_source, subpath = _split_dataset_path(src)
+ else:
+ assert_never(src)
+
+ path = download(file_source).path
+
+ if path.suffix == ".npy":
if subpath is not None:
raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}")
return load_array(path)
- elif file_path.suffix in (".h5", ".hdf", ".hdf5"):
+ elif path.suffix in SUFFIXES_WITH_DATAPATH:
if subpath is None:
dataset_path = DEFAULT_H5_DATASET_PATH
else:
dataset_path = str(subpath)
- with h5py.File(file_path, "r") as f:
+ with h5py.File(path, "r") as f:
h5_dataset = f.get( # pyright: ignore[reportUnknownVariableType]
dataset_path
)
@@ -60,41 +104,81 @@ def load_image(path: Path, is_volume: Optional[bool] = None) -> NDArray[Any]:
image # pyright: ignore[reportUnknownArgumentType]
)
return image # pyright: ignore[reportUnknownVariableType]
+ elif isinstance(path, ZipPath):
+ return imread(
+ path.read_bytes(), extension=path.suffix
+ ) # pyright: ignore[reportUnknownVariableType]
else:
return imread(path) # pyright: ignore[reportUnknownVariableType]
-def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor:
+def load_tensor(
+ path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
+) -> Tensor:
# TODO: load axis meta data
array = load_image(path)
return Tensor.from_numpy(array, dims=axes)
-def _split_dataset_path(path: Path) -> Tuple[Path, Optional[PurePosixPath]]:
+_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)
+
+
+def _split_dataset_path(
+ source: _SourceT,
+) -> Tuple[_SourceT, Optional[PurePosixPath]]:
"""Split off subpath (e.g. internal h5 dataset path)
from a file path following a file extension.
Examples:
>>> _split_dataset_path(Path("my_file.h5/dataset"))
- (PosixPath('my_file.h5'), PurePosixPath('dataset'))
+ (...Path('my_file.h5'), PurePosixPath('dataset'))
- If no suffix is detected the path is returned with
>>> _split_dataset_path(Path("my_plain_file"))
- (PosixPath('my_plain_file'), None)
+ (...Path('my_plain_file'), None)
"""
- if path.suffix:
+ if isinstance(source, RelativeFilePath):
+ src = source.absolute()
+ else:
+ src = source
+
+ del source
+
+ def separate_pure_path(path: PurePosixPath):
+ for p in path.parents:
+ if p.suffix in SUFFIXES_WITH_DATAPATH:
+ return p, PurePosixPath(path.relative_to(p))
+
return path, None
- for p in path.parents:
- if p.suffix:
- return p, PurePosixPath(path.relative_to(p))
+ if isinstance(src, HttpUrl):
+ file_path, data_path = separate_pure_path(PurePosixPath(src.path or ""))
- return path, None
+ if data_path is None:
+ return src, None
+ return (
+ HttpUrl(str(file_path).replace(f"/{data_path}", "")),
+ data_path,
+ )
+
+ if isinstance(src, ZipPath):
+ file_path, data_path = separate_pure_path(PurePosixPath(str(src)))
+
+ if data_path is None:
+ return src, None
+
+ return (
+ ZipPath(str(file_path).replace(f"/{data_path}", "")),
+ data_path,
+ )
-def save_tensor(path: Path, tensor: Tensor) -> None:
+ file_path, data_path = separate_pure_path(PurePosixPath(src))
+ return Path(file_path), data_path
+
+
+def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
# TODO: save axis meta data
data: NDArray[Any] = tensor.data.to_numpy()
@@ -134,23 +218,33 @@ def save_tensor(path: Path, tensor: Tensor) -> None:
imwrite(path, data)
-def save_sample(path: Union[Path, str, PerMember[Path]], sample: Sample) -> None:
- """save a sample to path
-
- If `path` is a pathlib.Path or a string it must contain `{member_id}` and may contain `{sample_id}`,
- which are resolved with the `sample` object.
- """
+def save_sample(
+ path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
+) -> None:
+ """Save a **sample** to a **path** pattern
+ or all sample members in the **path** mapping.
- if not isinstance(path, collections.abc.Mapping) and "{member_id}" not in str(path):
- raise ValueError(f"missing `{{member_id}}` in path {path}")
+ If **path** is a pathlib.Path or a string and the **sample** has multiple members,
+    **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`).
- for m, t in sample.members.items():
- if isinstance(path, collections.abc.Mapping):
- p = path[m]
+ (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
+ """
+ if not isinstance(path, collections.abc.Mapping):
+ if len(sample.members) < 2 or any(
+ m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
+ ):
+ path = {m: path for m in sample.members}
else:
- p = Path(str(path).format(sample_id=sample.id, member_id=m))
+ raise ValueError(
+ f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
+ )
- save_tensor(p, t)
+ for m, p in path.items():
+ t = sample.members[m]
+ p_formatted = Path(
+ str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
+ )
+ save_tensor(p_formatted, t)
class _SerializedDatasetStatsEntry(
@@ -176,3 +270,27 @@ def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
def load_dataset_stat(path: Path):
seq = _stat_adapter.validate_json(path.read_bytes())
return {e.measure: e.value for e in seq}
+
+
+def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path):
+ """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive.
+ Always returns the path to the unzipped source (maybe source itself)"""
+ local_weights_file = download(source).path
+ if isinstance(local_weights_file, ZipPath):
+ # source is inside a zip archive
+ out_path = folder / local_weights_file.filename
+ with local_weights_file.open("rb") as src, out_path.open("wb") as dst:
+ assert not isinstance(src, TextIOWrapper)
+ copyfileobj(src, dst)
+
+ local_weights_file = out_path
+
+ if zipfile.is_zipfile(local_weights_file):
+ # source itself is a zipfile
+ out_path = folder / local_weights_file.with_suffix(".unzipped").name
+ with zipfile.ZipFile(local_weights_file, "r") as f:
+ f.extractall(out_path)
+
+ return out_path
+ else:
+ return local_weights_file
diff --git a/bioimageio/core/model_adapters.py b/bioimageio/core/model_adapters.py
new file mode 100644
index 00000000..db92d013
--- /dev/null
+++ b/bioimageio/core/model_adapters.py
@@ -0,0 +1,22 @@
+"""DEPRECATED"""
+
+from typing import List
+
+from .backends._model_adapter import (
+ DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
+ ModelAdapter,
+ create_model_adapter,
+)
+
+__all__ = [
+ "ModelAdapter",
+ "create_model_adapter",
+ "get_weight_formats",
+]
+
+
+def get_weight_formats() -> List[str]:
+ """
+ Return list of supported weight types
+ """
+ return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)
diff --git a/bioimageio/core/model_adapters/__init__.py b/bioimageio/core/model_adapters/__init__.py
deleted file mode 100644
index 01899de9..00000000
--- a/bioimageio/core/model_adapters/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from ._model_adapter import ModelAdapter, create_model_adapter, get_weight_formats
-
-__all__ = [
- "ModelAdapter",
- "create_model_adapter",
- "get_weight_formats",
-]
diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py
deleted file mode 100644
index c918603e..00000000
--- a/bioimageio/core/model_adapters/_model_adapter.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import warnings
-from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Tuple, Union, final
-
-from bioimageio.spec.model import v0_4, v0_5
-
-from ..tensor import Tensor
-
-WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat]
-
-# Known weight formats in order of priority
-# First match wins
-DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = (
- "pytorch_state_dict",
- "tensorflow_saved_model_bundle",
- "torchscript",
- "onnx",
- "keras_hdf5",
-)
-
-
-class ModelAdapter(ABC):
- """
- Represents model *without* any preprocessing or postprocessing.
-
- ```
- from bioimageio.core import load_description
-
- model = load_description(...)
-
- # option 1:
- adapter = ModelAdapter.create(model)
- adapter.forward(...)
- adapter.unload()
-
- # option 2:
- with ModelAdapter.create(model) as adapter:
- adapter.forward(...)
- ```
- """
-
- @final
- @classmethod
- def create(
- cls,
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- *,
- devices: Optional[Sequence[str]] = None,
- weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
- ):
- """
- Creates model adapter based on the passed spec
- Note: All specific adapters should happen inside this function to prevent different framework
- initializations interfering with each other
- """
- if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
- raise TypeError(
- f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
- )
-
- weights = model_description.weights
- errors: List[Tuple[WeightsFormat, Exception]] = []
- weight_format_priority_order = (
- DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
- if weight_format_priority_order is None
- else weight_format_priority_order
- )
- # limit weight formats to the ones present
- weight_format_priority_order = [
- w for w in weight_format_priority_order if getattr(weights, w) is not None
- ]
-
- for wf in weight_format_priority_order:
- if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
- try:
- from ._pytorch_model_adapter import PytorchModelAdapter
-
- return PytorchModelAdapter(
- outputs=model_description.outputs,
- weights=weights.pytorch_state_dict,
- devices=devices,
- )
- except Exception as e:
- errors.append((wf, e))
- elif (
- wf == "tensorflow_saved_model_bundle"
- and weights.tensorflow_saved_model_bundle is not None
- ):
- try:
- from ._tensorflow_model_adapter import TensorflowModelAdapter
-
- return TensorflowModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append((wf, e))
- elif wf == "onnx" and weights.onnx is not None:
- try:
- from ._onnx_model_adapter import ONNXModelAdapter
-
- return ONNXModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append((wf, e))
- elif wf == "torchscript" and weights.torchscript is not None:
- try:
- from ._torchscript_model_adapter import TorchscriptModelAdapter
-
- return TorchscriptModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append((wf, e))
- elif wf == "keras_hdf5" and weights.keras_hdf5 is not None:
- # keras can either be installed as a separate package or used as part of tensorflow
- # we try to first import the keras model adapter using the separate package and,
- # if it is not available, try to load the one using tf
- try:
- from ._keras_model_adapter import (
- KerasModelAdapter,
- keras, # type: ignore
- )
-
- if keras is None:
- from ._tensorflow_model_adapter import KerasModelAdapter
-
- return KerasModelAdapter(
- model_description=model_description, devices=devices
- )
- except Exception as e:
- errors.append((wf, e))
-
- assert errors
- if len(weight_format_priority_order) == 1:
- assert len(errors) == 1
- raise ValueError(
- f"The '{weight_format_priority_order[0]}' model adapter could not be created"
- + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
- )
-
- else:
- error_list = "\n - ".join(
- f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
- )
- raise ValueError(
- "None of the weight format specific model adapters could be created"
- + f" in this environment. Errors are:\n\n{error_list}.\n\n"
- )
-
- @final
- def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
- warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
-
- @abstractmethod
- def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
- """
- Run forward pass of model to get model predictions
- """
- # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl
-
- @abstractmethod
- def unload(self):
- """
- Unload model from any devices, freeing their memory.
- The moder adapter should be considered unusable afterwards.
- """
-
-
-def get_weight_formats() -> List[str]:
- """
- Return list of supported weight types
- """
- return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)
-
-
-create_model_adapter = ModelAdapter.create
diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py
deleted file mode 100644
index c747de22..00000000
--- a/bioimageio/core/model_adapters/_onnx_model_adapter.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import warnings
-from typing import Any, List, Optional, Sequence, Union
-
-from numpy.typing import NDArray
-
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.utils import download
-
-from ..digest_spec import get_axes_infos
-from ..tensor import Tensor
-from ._model_adapter import ModelAdapter
-
-try:
- import onnxruntime as rt
-except Exception as e:
- rt = None
- rt_error = str(e)
-else:
- rt_error = None
-
-
-class ONNXModelAdapter(ModelAdapter):
- def __init__(
- self,
- *,
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- devices: Optional[Sequence[str]] = None,
- ):
- if rt is None:
- raise ImportError(f"failed to import onnxruntime: {rt_error}")
-
- super().__init__()
- self._internal_output_axes = [
- tuple(a.id for a in get_axes_infos(out))
- for out in model_description.outputs
- ]
- if model_description.weights.onnx is None:
- raise ValueError("No ONNX weights specified for {model_description.name}")
-
- self._session = rt.InferenceSession(
- str(download(model_description.weights.onnx.source).path)
- )
- onnx_inputs = self._session.get_inputs() # type: ignore
- self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] # type: ignore
-
- if devices is not None:
- warnings.warn(
- f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
- )
-
- def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
- assert len(input_tensors) == len(self._input_names)
- input_arrays = [None if ipt is None else ipt.data.data for ipt in input_tensors]
- result: Union[Sequence[Optional[NDArray[Any]]], Optional[NDArray[Any]]]
- result = self._session.run( # pyright: ignore[reportUnknownVariableType]
- None, dict(zip(self._input_names, input_arrays))
- )
- if isinstance(result, (list, tuple)):
- result_seq: Sequence[Optional[NDArray[Any]]] = result
- else:
- result_seq = [result] # type: ignore
-
- return [
- None if r is None else Tensor(r, dims=axes)
- for r, axes in zip(result_seq, self._internal_output_axes)
- ]
-
- def unload(self) -> None:
- warnings.warn(
- "Device management is not implemented for onnx yet, cannot unload model"
- )
diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
deleted file mode 100644
index a5178d74..00000000
--- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import gc
-import warnings
-from typing import Any, List, Optional, Sequence, Tuple, Union
-
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.utils import download
-
-from ..axis import AxisId
-from ..digest_spec import get_axes_infos, import_callable
-from ..tensor import Tensor
-from ._model_adapter import ModelAdapter
-
-try:
- import torch
-except Exception as e:
- torch = None
- torch_error = str(e)
-else:
- torch_error = None
-
-
-class PytorchModelAdapter(ModelAdapter):
- def __init__(
- self,
- *,
- outputs: Union[
- Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr]
- ],
- weights: Union[
- v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
- ],
- devices: Optional[Sequence[str]] = None,
- ):
- if torch is None:
- raise ImportError(f"failed to import torch: {torch_error}")
-
- super().__init__()
- self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs]
- self._network = self.get_network(weights)
- self._devices = self.get_devices(devices)
- self._network = self._network.to(self._devices[0])
-
- self._primary_device = self._devices[0]
- state: Any = torch.load(
- download(weights).path,
- map_location=self._primary_device, # pyright: ignore[reportUnknownArgumentType]
- )
- self._network.load_state_dict(state)
-
- self._network = self._network.eval()
-
- def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
- if torch is None:
- raise ImportError("torch")
- with torch.no_grad():
- tensors = [
- None if ipt is None else torch.from_numpy(ipt.data.data)
- for ipt in input_tensors
- ]
- tensors = [
- (
- None
- if t is None
- else t.to(
- self._primary_device # pyright: ignore[reportUnknownArgumentType]
- )
- )
- for t in tensors
- ]
- result: Union[Tuple[Any, ...], List[Any], Any]
- result = self._network( # pyright: ignore[reportUnknownVariableType]
- *tensors
- )
- if not isinstance(result, (tuple, list)):
- result = [result]
-
- result = [
- (
- None
- if r is None
- else r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r
- )
- for r in result # pyright: ignore[reportUnknownVariableType]
- ]
- if len(result) > len(self.output_dims):
- raise ValueError(
- f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}"
- )
-
- return [
- None if r is None else Tensor(r, dims=out)
- for r, out in zip(result, self.output_dims)
- ]
-
- def unload(self) -> None:
- del self._network
- _ = gc.collect() # deallocate memory
- assert torch is not None
- torch.cuda.empty_cache() # release reserved memory
-
- @staticmethod
- def get_network( # pyright: ignore[reportUnknownParameterType]
- weight_spec: Union[
- v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
- ],
- ) -> "torch.nn.Module": # pyright: ignore[reportInvalidTypeForm]
- if torch is None:
- raise ImportError("torch")
- arch = import_callable(
- weight_spec.architecture,
- sha256=(
- weight_spec.architecture_sha256
- if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
- else weight_spec.sha256
- ),
- )
- model_kwargs = (
- weight_spec.kwargs
- if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
- else weight_spec.architecture.kwargs
- )
- network = arch(**model_kwargs)
- if not isinstance(network, torch.nn.Module):
- raise ValueError(
- f"calling {weight_spec.architecture.callable} did not return a torch.nn.Module"
- )
-
- return network
-
- @staticmethod
- def get_devices( # pyright: ignore[reportUnknownParameterType]
- devices: Optional[Sequence[str]] = None,
- ) -> List["torch.device"]: # pyright: ignore[reportInvalidTypeForm]
- if torch is None:
- raise ImportError("torch")
- if not devices:
- torch_devices = [
- (
- torch.device("cuda")
- if torch.cuda.is_available()
- else torch.device("cpu")
- )
- ]
- else:
- torch_devices = [torch.device(d) for d in devices]
-
- if len(torch_devices) > 1:
- warnings.warn(
- f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}"
- )
- torch_devices = torch_devices[:1]
-
- return torch_devices
diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
deleted file mode 100644
index cfb264f0..00000000
--- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
+++ /dev/null
@@ -1,275 +0,0 @@
-import zipfile
-from typing import List, Literal, Optional, Sequence, Union
-
-import numpy as np
-from loguru import logger
-
-from bioimageio.spec.common import FileSource
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.utils import download
-
-from ..digest_spec import get_axes_infos
-from ..tensor import Tensor
-from ._model_adapter import ModelAdapter
-
-try:
- import tensorflow as tf # pyright: ignore[reportMissingImports]
-except Exception as e:
- tf = None
- tf_error = str(e)
-else:
- tf_error = None
-
-
-class TensorflowModelAdapterBase(ModelAdapter):
- weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"]
-
- def __init__(
- self,
- *,
- devices: Optional[Sequence[str]] = None,
- weights: Union[
- v0_4.KerasHdf5WeightsDescr,
- v0_4.TensorflowSavedModelBundleWeightsDescr,
- v0_5.KerasHdf5WeightsDescr,
- v0_5.TensorflowSavedModelBundleWeightsDescr,
- ],
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- ):
- if tf is None:
- raise ImportError(f"failed to import tensorflow: {tf_error}")
-
- super().__init__()
- self.model_description = model_description
- tf_version = v0_5.Version(
- tf.__version__ # pyright: ignore[reportUnknownArgumentType]
- )
- model_tf_version = weights.tensorflow_version
- if model_tf_version is None:
- logger.warning(
- "The model does not specify the tensorflow version."
- + f"Cannot check if it is compatible with intalled tensorflow {tf_version}."
- )
- elif model_tf_version > tf_version:
- logger.warning(
- f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
- )
- elif (model_tf_version.major, model_tf_version.minor) != (
- tf_version.major,
- tf_version.minor,
- ):
- logger.warning(
- "The tensorflow version specified by the model does not match the installed: "
- + f"{model_tf_version} != {tf_version}."
- )
-
- self.use_keras_api = (
- tf_version.major > 1
- or self.weight_format == KerasModelAdapter.weight_format
- )
-
- # TODO tf device management
- if devices is not None:
- logger.warning(
- f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
- )
-
- weight_file = self.require_unzipped(weights.source)
- self._network = self._get_network(weight_file)
- self._internal_output_axes = [
- tuple(a.id for a in get_axes_infos(out))
- for out in model_description.outputs
- ]
-
- def require_unzipped(self, weight_file: FileSource):
- loacl_weights_file = download(weight_file).path
- if zipfile.is_zipfile(loacl_weights_file):
- out_path = loacl_weights_file.with_suffix(".unzipped")
- with zipfile.ZipFile(loacl_weights_file, "r") as f:
- f.extractall(out_path)
-
- return out_path
- else:
- return loacl_weights_file
-
- def _get_network( # pyright: ignore[reportUnknownParameterType]
- self, weight_file: FileSource
- ):
- weight_file = self.require_unzipped(weight_file)
- assert tf is not None
- if self.use_keras_api:
- try:
- return tf.keras.layers.TFSMLayer(
- weight_file, call_endpoint="serve"
- ) # pyright: ignore[reportUnknownVariableType]
- except Exception as e:
- try:
- return tf.keras.layers.TFSMLayer(
- weight_file, call_endpoint="serving_default"
- ) # pyright: ignore[reportUnknownVariableType]
- except Exception as ee:
- logger.opt(exception=ee).info(
- "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
- )
- raise e
- else:
- # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
- return str(weight_file)
-
- # TODO currently we relaod the model every time. it would be better to keep the graph and session
- # alive in between of forward passes (but then the sessions need to be properly opened / closed)
- def _forward_tf( # pyright: ignore[reportUnknownParameterType]
- self, *input_tensors: Optional[Tensor]
- ):
- assert tf is not None
- input_keys = [
- ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id
- for ipt in self.model_description.inputs
- ]
- output_keys = [
- out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id
- for out in self.model_description.outputs
- ]
- # TODO read from spec
- tag = ( # pyright: ignore[reportUnknownVariableType]
- tf.saved_model.tag_constants.SERVING
- )
- signature_key = ( # pyright: ignore[reportUnknownVariableType]
- tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- )
-
- graph = tf.Graph() # pyright: ignore[reportUnknownVariableType]
- with graph.as_default():
- with tf.Session(
- graph=graph
- ) as sess: # pyright: ignore[reportUnknownVariableType]
- # load the model and the signature
- graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType]
- sess, [tag], self._network
- )
- signature = ( # pyright: ignore[reportUnknownVariableType]
- graph_def.signature_def
- )
-
- # get the tensors into the graph
- in_names = [ # pyright: ignore[reportUnknownVariableType]
- signature[signature_key].inputs[key].name for key in input_keys
- ]
- out_names = [ # pyright: ignore[reportUnknownVariableType]
- signature[signature_key].outputs[key].name for key in output_keys
- ]
- in_tensors = [ # pyright: ignore[reportUnknownVariableType]
- graph.get_tensor_by_name(name)
- for name in in_names # pyright: ignore[reportUnknownVariableType]
- ]
- out_tensors = [ # pyright: ignore[reportUnknownVariableType]
- graph.get_tensor_by_name(name)
- for name in out_names # pyright: ignore[reportUnknownVariableType]
- ]
-
- # run prediction
- res = sess.run( # pyright: ignore[reportUnknownVariableType]
- dict(
- zip(
- out_names, # pyright: ignore[reportUnknownArgumentType]
- out_tensors, # pyright: ignore[reportUnknownArgumentType]
- )
- ),
- dict(
- zip(
- in_tensors, # pyright: ignore[reportUnknownArgumentType]
- input_tensors,
- )
- ),
- )
- # from dict to list of tensors
- res = [ # pyright: ignore[reportUnknownVariableType]
- res[out]
- for out in out_names # pyright: ignore[reportUnknownVariableType]
- ]
-
- return res # pyright: ignore[reportUnknownVariableType]
-
- def _forward_keras( # pyright: ignore[reportUnknownParameterType]
- self, *input_tensors: Optional[Tensor]
- ):
- assert self.use_keras_api
- assert not isinstance(self._network, str)
- assert tf is not None
- tf_tensor = [ # pyright: ignore[reportUnknownVariableType]
- None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors
- ]
-
- result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType]
-
- assert isinstance(result, dict)
-
- # TODO: Use RDF's `outputs[i].id` here
- result = list(result.values())
-
- return [ # pyright: ignore[reportUnknownVariableType]
- (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
- for r in result # pyright: ignore[reportUnknownVariableType]
- ]
-
- def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
- data = [None if ipt is None else ipt.data for ipt in input_tensors]
- if self.use_keras_api:
- result = self._forward_keras( # pyright: ignore[reportUnknownVariableType]
- *data
- )
- else:
- result = self._forward_tf( # pyright: ignore[reportUnknownVariableType]
- *data
- )
-
- return [
- None if r is None else Tensor(r, dims=axes)
- for r, axes in zip( # pyright: ignore[reportUnknownVariableType]
- result, # pyright: ignore[reportUnknownArgumentType]
- self._internal_output_axes,
- )
- ]
-
- def unload(self) -> None:
- logger.warning(
- "Device management is not implemented for keras yet, cannot unload model"
- )
-
-
-class TensorflowModelAdapter(TensorflowModelAdapterBase):
- weight_format = "tensorflow_saved_model_bundle"
-
- def __init__(
- self,
- *,
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- devices: Optional[Sequence[str]] = None,
- ):
- if model_description.weights.tensorflow_saved_model_bundle is None:
- raise ValueError("missing tensorflow_saved_model_bundle weights")
-
- super().__init__(
- devices=devices,
- weights=model_description.weights.tensorflow_saved_model_bundle,
- model_description=model_description,
- )
-
-
-class KerasModelAdapter(TensorflowModelAdapterBase):
- weight_format = "keras_hdf5"
-
- def __init__(
- self,
- *,
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- devices: Optional[Sequence[str]] = None,
- ):
- if model_description.weights.keras_hdf5 is None:
- raise ValueError("missing keras_hdf5 weights")
-
- super().__init__(
- model_description=model_description,
- devices=devices,
- weights=model_description.weights.keras_hdf5,
- )
diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py
deleted file mode 100644
index 0e9f3aef..00000000
--- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import gc
-import warnings
-from typing import Any, List, Optional, Sequence, Tuple, Union
-
-import numpy as np
-from numpy.typing import NDArray
-
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.utils import download
-
-from ..digest_spec import get_axes_infos
-from ..tensor import Tensor
-from ._model_adapter import ModelAdapter
-
-try:
- import torch
-except Exception as e:
- torch = None
- torch_error = str(e)
-else:
- torch_error = None
-
-
-class TorchscriptModelAdapter(ModelAdapter):
- def __init__(
- self,
- *,
- model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- devices: Optional[Sequence[str]] = None,
- ):
- if torch is None:
- raise ImportError(f"failed to import torch: {torch_error}")
-
- super().__init__()
- if model_description.weights.torchscript is None:
- raise ValueError(
- f"No torchscript weights found for model {model_description.name}"
- )
-
- weight_path = download(model_description.weights.torchscript.source).path
- if devices is None:
- self.devices = ["cuda" if torch.cuda.is_available() else "cpu"]
- else:
- self.devices = [torch.device(d) for d in devices]
-
- if len(self.devices) > 1:
- warnings.warn(
- "Multiple devices for single torchscript model not yet implemented"
- )
-
- self._model = torch.jit.load(weight_path)
- self._model.to(self.devices[0])
- self._model = self._model.eval()
- self._internal_output_axes = [
- tuple(a.id for a in get_axes_infos(out))
- for out in model_description.outputs
- ]
-
- def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]:
- assert torch is not None
- with torch.no_grad():
- torch_tensor = [
- None if b is None else torch.from_numpy(b.data.data).to(self.devices[0])
- for b in batch
- ]
- _result: Union[ # pyright: ignore[reportUnknownVariableType]
- Tuple[Optional[NDArray[Any]], ...],
- List[Optional[NDArray[Any]]],
- Optional[NDArray[Any]],
- ] = self._model.forward(*torch_tensor)
- if isinstance(_result, (tuple, list)):
- result: Sequence[Optional[NDArray[Any]]] = _result
- else:
- result = [_result]
-
- result = [
- (
- None
- if r is None
- else r.cpu().numpy() if not isinstance(r, np.ndarray) else r
- )
- for r in result
- ]
-
- assert len(result) == len(self._internal_output_axes)
- return [
- None if r is None else Tensor(r, dims=axes)
- for r, axes in zip(result, self._internal_output_axes)
- ]
-
- def unload(self) -> None:
- assert torch is not None
- self._devices = None
- del self._model
- _ = gc.collect() # deallocate memory
- torch.cuda.empty_cache() # release reserved memory
diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py
index 27a4129c..9fe5ce12 100644
--- a/bioimageio/core/prediction.py
+++ b/bioimageio/core/prediction.py
@@ -1,7 +1,6 @@
import collections.abc
from pathlib import Path
from typing import (
- Any,
Hashable,
Iterable,
Iterator,
@@ -11,9 +10,7 @@
Union,
)
-import xarray as xr
from loguru import logger
-from numpy.typing import NDArray
from tqdm import tqdm
from bioimageio.spec import load_description
@@ -22,11 +19,10 @@
from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
from .axis import AxisId
-from .common import MemberId, PerMember
-from .digest_spec import create_sample_for_model
+from .common import BlocksizeParameter, MemberId, PerMember
+from .digest_spec import TensorSource, create_sample_for_model, get_member_id
from .io import save_sample
from .sample import Sample
-from .tensor import Tensor
def predict(
@@ -34,14 +30,9 @@ def predict(
model: Union[
PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
],
- inputs: Union[Sample, PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
+ inputs: Union[Sample, PerMember[TensorSource], TensorSource],
sample_id: Hashable = "sample",
- blocksize_parameter: Optional[
- Union[
- v0_5.ParameterizedSize_N,
- Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
- ]
- ] = None,
+ blocksize_parameter: Optional[BlocksizeParameter] = None,
input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
skip_preprocessing: bool = False,
skip_postprocessing: bool = False,
@@ -50,29 +41,31 @@ def predict(
"""Run prediction for a single set of input(s) with a bioimage.io model
Args:
- model: model to predict with.
+ model: Model to predict with.
May be given as RDF source, model description or prediction pipeline.
inputs: the input sample or the named input(s) for this model as a dictionary
sample_id: the sample id.
- blocksize_parameter: (optional) tile the input into blocks parametrized by
- blocksize according to any parametrized axis sizes defined in the model RDF.
- Note: For a predetermined, fixed block shape use `input_block_shape`
- input_block_shape: (optional) tile the input sample tensors into blocks.
- Note: For a parameterized block shape, not dealing with the exact block shape,
- use `blocksize_parameter`.
- skip_preprocessing: flag to skip the model's preprocessing
- skip_postprocessing: flag to skip the model's postprocessing
- save_output_path: A path with `{member_id}` `{sample_id}` in it
- to save the output to.
+ The **sample_id** is used to format **save_output_path**
+ and to distinguish sample specific log messages.
+ blocksize_parameter: (optional) Tile the input into blocks parametrized by
+ **blocksize_parameter** according to any parametrized axis sizes defined
+ by the **model**.
+ See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
+ Note: For a predetermined, fixed block shape use **input_block_shape**.
+ input_block_shape: (optional) Tile the input sample tensors into blocks.
+ Note: Use **blocksize_parameter** for a parameterized block shape to
+ run prediction independent of the exact block shape.
+ skip_preprocessing: Flag to skip the model's preprocessing.
+ skip_postprocessing: Flag to skip the model's postprocessing.
+ save_output_path: A path with to save the output to. M
+ Must contain:
+ - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
+ May contain:
+ - `{sample_id}` to avoid overwriting recurrent calls
"""
- if save_output_path is not None:
- if "{member_id}" not in str(save_output_path):
- raise ValueError(
- f"Missing `{{member_id}}` in save_output_path={save_output_path}"
- )
-
if isinstance(model, PredictionPipeline):
pp = model
+ model = pp.model_description
else:
if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
loaded = load_description(model)
@@ -82,6 +75,18 @@ def predict(
pp = create_prediction_pipeline(model)
+ if save_output_path is not None:
+ if (
+ "{output_id}" not in str(save_output_path)
+ and "{member_id}" not in str(save_output_path)
+ and len(model.outputs) > 1
+ ):
+ raise ValueError(
+ f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
+ + "distinguish model outputs "
+ + str([get_member_id(d) for d in model.outputs])
+ )
+
if isinstance(inputs, Sample):
sample = inputs
else:
@@ -127,7 +132,7 @@ def predict_many(
model: Union[
PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
],
- inputs: Iterable[PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
+ inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
sample_id: str = "sample{i:03}",
blocksize_parameter: Optional[
Union[
@@ -142,31 +147,27 @@ def predict_many(
"""Run prediction for a multiple sets of inputs with a bioimage.io model
Args:
- model: model to predict with.
+ model: Model to predict with.
May be given as RDF source, model description or prediction pipeline.
inputs: An iterable of the named input(s) for this model as a dictionary.
- sample_id: the sample id.
+ sample_id: The sample id.
note: `{i}` will be formatted as the i-th sample.
- If `{i}` (or `{i:`) is not present and `inputs` is an iterable `{i:03}` is appended.
- blocksize_parameter: (optional) tile the input into blocks parametrized by
- blocksize according to any parametrized axis sizes defined in the model RDF
- skip_preprocessing: flag to skip the model's preprocessing
- skip_postprocessing: flag to skip the model's postprocessing
- save_output_path: A path with `{member_id}` `{sample_id}` in it
- to save the output to.
+ If `{i}` (or `{i:`) is not present and `inputs` is not an iterable `{i:03}`
+ is appended.
+ blocksize_parameter: (optional) Tile the input into blocks parametrized by
+ blocksize according to any parametrized axis sizes defined in the model RDF.
+ skip_preprocessing: Flag to skip the model's preprocessing.
+ skip_postprocessing: Flag to skip the model's postprocessing.
+ save_output_path: A path to save the output to.
+ Must contain:
+ - `{sample_id}` to differentiate predicted samples
+ - `{output_id}` (or `{member_id}`) if the model has multiple outputs
"""
- if save_output_path is not None:
- if "{member_id}" not in str(save_output_path):
- raise ValueError(
- f"Missing `{{member_id}}` in save_output_path={save_output_path}"
- )
-
- if not isinstance(inputs, collections.abc.Mapping) and "{sample_id}" not in str(
- save_output_path
- ):
- raise ValueError(
- f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
- )
+ if save_output_path is not None and "{sample_id}" not in str(save_output_path):
+ raise ValueError(
+ f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
+ + " to differentiate predicted samples."
+ )
if isinstance(model, PredictionPipeline):
pp = model
@@ -180,7 +181,6 @@ def predict_many(
pp = create_prediction_pipeline(model)
if not isinstance(inputs, collections.abc.Mapping):
- sample_id = str(sample_id)
if "{i}" not in sample_id and "{i:" not in sample_id:
sample_id += "{i:03}"
diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py
index eecf47b1..e504bf07 100644
--- a/bioimageio/core/proc_ops.py
+++ b/bioimageio/core/proc_ops.py
@@ -16,7 +16,11 @@
import xarray as xr
from typing_extensions import Self, assert_never
+from bioimageio.core.digest_spec import get_member_id
from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.model.v0_5 import (
+ _convert_proc, # pyright: ignore [reportPrivateUsage]
+)
from ._op_base import BlockedOperator, Operator
from .axis import AxisId, PerAxis
@@ -51,7 +55,7 @@ def _convert_axis_ids(
if mode == "per_sample":
ret = []
elif mode == "per_dataset":
- ret = [AxisId("b")]
+ ret = [v0_5.BATCH_AXIS_ID]
else:
assert_never(mode)
@@ -299,9 +303,15 @@ def from_proc_descr(
member_id: MemberId,
) -> Self:
kwargs = descr.kwargs
- if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
+ if isinstance(kwargs, v0_5.ScaleLinearKwargs):
+ axis = None
+ elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
axis = kwargs.axis
- elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
+ elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
+ if kwargs.axes is not None:
+ raise NotImplementedError(
+ "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
+ )
axis = None
else:
assert_never(kwargs)
@@ -605,7 +615,7 @@ def from_proc_descr(
if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
dims = None
elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
- dims = (descr.kwargs.axis,)
+ dims = (AxisId(descr.kwargs.axis),)
else:
assert_never(descr.kwargs)
@@ -657,34 +667,53 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor:
]
-def get_proc_class(proc_spec: ProcDescr):
- if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
- return Binarize
- elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
- return Clip
- elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
- return EnsureDtype
- elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
- return FixedZeroMeanUnitVariance
- elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
- return ScaleLinear
+def get_proc(
+ proc_descr: ProcDescr,
+ tensor_descr: Union[
+ v0_4.InputTensorDescr,
+ v0_4.OutputTensorDescr,
+ v0_5.InputTensorDescr,
+ v0_5.OutputTensorDescr,
+ ],
+) -> Processing:
+ member_id = get_member_id(tensor_descr)
+
+ if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
+ return Binarize.from_proc_descr(proc_descr, member_id)
+ elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
+ return Clip.from_proc_descr(proc_descr, member_id)
+ elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
+ return EnsureDtype.from_proc_descr(proc_descr, member_id)
+ elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
+ return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
+ elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
+ return ScaleLinear.from_proc_descr(proc_descr, member_id)
elif isinstance(
- proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
+ proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
):
- return ScaleMeanVariance
- elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
- return ScaleRange
- elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
- return Sigmoid
+ return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
+ elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
+ return ScaleRange.from_proc_descr(proc_descr, member_id)
+ elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
+ return Sigmoid.from_proc_descr(proc_descr, member_id)
elif (
- isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr)
- and proc_spec.kwargs.mode == "fixed"
+ isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
+ and proc_descr.kwargs.mode == "fixed"
):
- return FixedZeroMeanUnitVariance
+ if not isinstance(
+ tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
+ ):
+ raise TypeError(
+ "Expected v0_4 tensor description for v0_4 processing description"
+ )
+
+ v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
+ assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
+ return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
elif isinstance(
- proc_spec,
+ proc_descr,
(v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
):
- return ZeroMeanUnitVariance
+ return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
else:
- assert_never(proc_spec)
+ assert_never(proc_descr)
diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py
index b9afb711..1ada58b2 100644
--- a/bioimageio/core/proc_setup.py
+++ b/bioimageio/core/proc_setup.py
@@ -1,4 +1,5 @@
from typing import (
+ Callable,
Iterable,
List,
Mapping,
@@ -11,15 +12,15 @@
from typing_extensions import assert_never
+from bioimageio.core.digest_spec import get_member_id
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
-from bioimageio.spec.model.v0_5 import TensorId
-from .digest_spec import get_member_ids
from .proc_ops import (
AddKnownDatasetStats,
+ EnsureDtype,
Processing,
UpdateStats,
- get_proc_class,
+ get_proc,
)
from .sample import Sample
from .stat_calculators import StatsCalculator
@@ -45,6 +46,11 @@ class PreAndPostprocessing(NamedTuple):
post: List[Processing]
+class _ProcessingCallables(NamedTuple):
+ pre: Callable[[Sample], None]
+ post: Callable[[Sample], None]
+
+
class _SetupProcessing(NamedTuple):
pre: List[Processing]
post: List[Processing]
@@ -52,6 +58,34 @@ class _SetupProcessing(NamedTuple):
post_measures: Set[Measure]
+class _ApplyProcs:
+ def __init__(self, procs: Sequence[Processing]):
+ super().__init__()
+ self._procs = procs
+
+ def __call__(self, sample: Sample) -> None:
+ for op in self._procs:
+ op(sample)
+
+
+def get_pre_and_postprocessing(
+ model: AnyModelDescr,
+ *,
+ dataset_for_initial_statistics: Iterable[Sample],
+ keep_updating_initial_dataset_stats: bool = False,
+ fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
+) -> _ProcessingCallables:
+ """Creates callables to apply pre- and postprocessing in-place to a sample"""
+
+ setup = setup_pre_and_postprocessing(
+ model=model,
+ dataset_for_initial_statistics=dataset_for_initial_statistics,
+ keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
+ fixed_dataset_stats=fixed_dataset_stats,
+ )
+ return _ProcessingCallables(_ApplyProcs(setup.pre), _ApplyProcs(setup.post))
+
+
def setup_pre_and_postprocessing(
model: AnyModelDescr,
dataset_for_initial_statistics: Iterable[Sample],
@@ -60,7 +94,7 @@ def setup_pre_and_postprocessing(
) -> PreAndPostprocessing:
"""
Get pre- and postprocessing operators for a `model` description.
- userd in `bioimageio.core.create_prediction_pipeline"""
+ Used in `bioimageio.core.create_prediction_pipeline"""
prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)
missing_dataset_stats = {
@@ -136,65 +170,63 @@ def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures
)
+def _prepare_procs(
+ tensor_descrs: Union[
+ Sequence[v0_4.InputTensorDescr],
+ Sequence[v0_5.InputTensorDescr],
+ Sequence[v0_4.OutputTensorDescr],
+ Sequence[v0_5.OutputTensorDescr],
+ ],
+) -> List[Processing]:
+ procs: List[Processing] = []
+ for t_descr in tensor_descrs:
+ if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
+ member_id = get_member_id(t_descr)
+ procs.append(
+ EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
+ )
+
+ if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
+ for proc_d in t_descr.preprocessing:
+ procs.append(get_proc(proc_d, t_descr))
+ elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
+ for proc_d in t_descr.postprocessing:
+ procs.append(get_proc(proc_d, t_descr))
+ else:
+ assert_never(t_descr)
+
+ if isinstance(
+ t_descr,
+ (v0_4.InputTensorDescr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)),
+ ):
+ if len(procs) == 1:
+ # remove initial ensure_dtype if there are no other proccessing steps
+ assert isinstance(procs[0], EnsureDtype)
+ procs = []
+
+ # ensure 0.4 models get float32 input
+ # which has been the implicit assumption for 0.4
+ member_id = get_member_id(t_descr)
+ procs.append(
+ EnsureDtype(input=member_id, output=member_id, dtype="float32")
+ )
+
+ return procs
+
+
def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
- pre_measures: Set[Measure] = set()
- post_measures: Set[Measure] = set()
-
- input_ids = set(get_member_ids(model.inputs))
- output_ids = set(get_member_ids(model.outputs))
-
- def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
- procs: List[Processing] = []
- for t_descr in tensor_descrs:
- if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
- proc_descrs: List[
- Union[
- v0_4.PreprocessingDescr,
- v0_5.PreprocessingDescr,
- v0_4.PostprocessingDescr,
- v0_5.PostprocessingDescr,
- ]
- ] = list(t_descr.preprocessing)
- elif isinstance(
- t_descr,
- (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
- ):
- proc_descrs = list(t_descr.postprocessing)
- else:
- assert_never(t_descr)
-
- if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
- ensure_dtype = v0_5.EnsureDtypeDescr(
- kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
- )
- if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
- proc_descrs.insert(0, ensure_dtype)
-
- proc_descrs.append(ensure_dtype)
-
- for proc_d in proc_descrs:
- proc_class = get_proc_class(proc_d)
- member_id = (
- TensorId(str(t_descr.name))
- if isinstance(t_descr, v0_4.TensorDescrBase)
- else t_descr.id
- )
- req = proc_class.from_proc_descr(
- proc_d, member_id # pyright: ignore[reportArgumentType]
- )
- for m in req.required_measures:
- if m.member_id in input_ids:
- pre_measures.add(m)
- elif m.member_id in output_ids:
- post_measures.add(m)
- else:
- raise ValueError("When to raise ")
- procs.append(req)
- return procs
+ if isinstance(model, v0_4.ModelDescr):
+ pre = _prepare_procs(model.inputs)
+ post = _prepare_procs(model.outputs)
+ elif isinstance(model, v0_5.ModelDescr):
+ pre = _prepare_procs(model.inputs)
+ post = _prepare_procs(model.outputs)
+ else:
+ assert_never(model)
return _SetupProcessing(
- pre=prepare_procs(model.inputs),
- post=prepare_procs(model.outputs),
- pre_measures=pre_measures,
- post_measures=post_measures,
+ pre=pre,
+ post=post,
+ pre_measures={m for proc in pre for m in proc.required_measures},
+ post_measures={m for proc in post for m in proc.required_measures},
)
diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py
index 0620282d..0d4c3724 100644
--- a/bioimageio/core/sample.py
+++ b/bioimageio/core/sample.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from math import ceil, floor
from typing import (
+ Any,
Callable,
Dict,
Generic,
@@ -14,6 +15,7 @@
)
import numpy as np
+from numpy.typing import NDArray
from typing_extensions import Self
from .axis import AxisId, PerAxis
@@ -42,21 +44,31 @@
@dataclass
class Sample:
- """A dataset sample"""
+ """A dataset sample.
+
+ A `Sample` has `members`, which allows to combine multiple tensors into a single
+ sample.
+ For example a `Sample` from a dataset with masked images may contain a
+ `MemberId("raw")` and `MemberId("mask")` image.
+ """
members: Dict[MemberId, Tensor]
- """the sample's tensors"""
+ """The sample's tensors"""
stat: Stat
- """sample and dataset statistics"""
+ """Sample and dataset statistics"""
id: SampleId
- """identifier within the sample's dataset"""
+ """Identifies the `Sample` within the dataset -- typically a number or a string."""
@property
def shape(self) -> PerMember[PerAxis[int]]:
return {tid: t.sizes for tid, t in self.members.items()}
+ def as_arrays(self) -> Dict[str, NDArray[Any]]:
+ """Return sample as dictionary of arrays."""
+ return {str(m): t.data.to_numpy() for m, t in self.members.items()}
+
def split_into_blocks(
self,
block_shapes: PerMember[PerAxis[int]],
diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py
index 41233a5b..515fe843 100644
--- a/bioimageio/core/stat_calculators.py
+++ b/bioimageio/core/stat_calculators.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-import collections.abc
+import collections
import warnings
from itertools import product
from typing import (
@@ -26,6 +26,8 @@
from numpy.typing import NDArray
from typing_extensions import assert_never
+from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID
+
from .axis import AxisId, PerAxis
from .common import MemberId
from .sample import Sample
@@ -47,7 +49,7 @@
from .tensor import Tensor
try:
- import crick
+ import crick # pyright: ignore[reportMissingImports]
except Exception:
crick = None
@@ -120,7 +122,7 @@ class MeanVarStdCalculator:
def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
super().__init__()
- self._axes = None if axes is None else tuple(axes)
+ self._axes = None if axes is None else tuple(map(AxisId, axes))
self._member_id = member_id
self._n: int = 0
self._mean: Optional[Tensor] = None
@@ -137,7 +139,15 @@ def compute(
else:
n = int(np.prod([tensor.sizes[d] for d in self._axes]))
- var = xr.dot(c, c, dims=self._axes) / n
+ if xr.__version__.startswith("2023"):
+ var = ( # pyright: ignore[reportUnknownVariableType]
+ xr.dot(c, c, dims=self._axes) / n
+ )
+ else:
+ var = ( # pyright: ignore[reportUnknownVariableType]
+ xr.dot(c, c, dim=self._axes) / n
+ )
+
assert isinstance(var, xr.DataArray)
std = np.sqrt(var)
assert isinstance(std, xr.DataArray)
@@ -152,6 +162,9 @@ def compute(
}
def update(self, sample: Sample):
+ if self._axes is not None and BATCH_AXIS_ID not in self._axes:
+ return
+
tensor = sample.members[self._member_id].astype("float64", copy=False)
mean_b = tensor.mean(dim=self._axes)
assert mean_b.dtype == "float64"
@@ -178,12 +191,16 @@ def update(self, sample: Sample):
def finalize(
self,
) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
- if self._mean is None:
+ if (
+ self._axes is not None
+ and BATCH_AXIS_ID not in self._axes
+ or self._mean is None
+ ):
return {}
else:
assert self._m2 is not None
var = self._m2 / self._n
- sqrt = np.sqrt(var)
+ sqrt = var**0.5
if isinstance(sqrt, (int, float)):
# var and mean are scalar tensors, let's keep it consistent
sqrt = Tensor.from_xarray(xr.DataArray(sqrt))
@@ -306,7 +323,8 @@ def _initialize(self, tensor_sizes: PerAxis[int]):
out_sizes[d] = s
self._dims, self._shape = zip(*out_sizes.items())
- d = int(np.prod(self._shape[1:])) # type: ignore
+ assert self._shape is not None
+ d = int(np.prod(self._shape[1:]))
self._digest = [TDigest() for _ in range(d)]
self._indices = product(*map(range, self._shape[1:]))
diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py
index 57148058..cb3b3da9 100644
--- a/bioimageio/core/tensor.py
+++ b/bioimageio/core/tensor.py
@@ -186,9 +186,20 @@ def dims(self): # TODO: rename to `axes`?
return cast(Tuple[AxisId, ...], self._data.dims)
@property
- def tagged_shape(self):
- """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
- return self.sizes
+ def dtype(self) -> DTypeStr:
+ dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType]
+ assert dt in get_args(DTypeStr)
+ return dt # pyright: ignore[reportReturnType]
+
+ @property
+ def ndim(self):
+ """Number of tensor dimensions."""
+ return self._data.ndim
+
+ @property
+ def shape(self):
+ """Tuple of tensor axes lengths"""
+ return self._data.shape
@property
def shape_tuple(self):
@@ -203,26 +214,21 @@ def size(self):
"""
return self._data.size
- def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
- """Reduce this Tensor's data by applying sum along some dimension(s)."""
- return self.__class__.from_xarray(self._data.sum(dim=dim))
-
- @property
- def ndim(self):
- """Number of tensor dimensions."""
- return self._data.ndim
-
- @property
- def dtype(self) -> DTypeStr:
- dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType]
- assert dt in get_args(DTypeStr)
- return dt # pyright: ignore[reportReturnType]
-
@property
def sizes(self):
"""Ordered, immutable mapping from axis ids to axis lengths."""
return cast(Mapping[AxisId, int], self.data.sizes)
+ @property
+ def tagged_shape(self):
+ """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
+ return self.sizes
+
+ def argmax(self) -> Mapping[AxisId, int]:
+ ret = self._data.argmax(...)
+ assert isinstance(ret, dict)
+ return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
+
def astype(self, dtype: DTypeStr, *, copy: bool = False):
"""Return tensor cast to `dtype`
@@ -282,14 +288,23 @@ def crop_to(
def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
- def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
- return self.__class__.from_xarray(self._data.mean(dim=dim))
+ def item(
+ self,
+ key: Union[
+ None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
+ ] = None,
+ ):
+ """Copy a tensor element to a standard Python scalar and return it."""
+ if key is None:
+ ret = self._data.item()
+ else:
+ ret = self[key]._data.item()
- def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
- return self.__class__.from_xarray(self._data.std(dim=dim))
+ assert isinstance(ret, (bool, float, int))
+ return ret
- def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
- return self.__class__.from_xarray(self._data.var(dim=dim))
+ def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
+ return self.__class__.from_xarray(self._data.mean(dim=dim))
def pad(
self,
@@ -405,6 +420,13 @@ def resize_to(
return tensor
+ def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
+ return self.__class__.from_xarray(self._data.std(dim=dim))
+
+ def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
+ """Reduce this Tensor's data by applying sum along some dimension(s)."""
+ return self.__class__.from_xarray(self._data.sum(dim=dim))
+
def transpose(
self,
axes: Sequence[AxisId],
@@ -423,6 +445,9 @@ def transpose(
# transpose to the correct axis order
return self.__class__.from_xarray(array.transpose(*axes))
+ def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
+ return self.__class__.from_xarray(self._data.var(dim=dim))
+
@classmethod
def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
ndim = array.ndim
diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py
index 84e94d38..695f0172 100644
--- a/bioimageio/core/utils/__init__.py
+++ b/bioimageio/core/utils/__init__.py
@@ -2,6 +2,8 @@
import sys
from pathlib import Path
+from ._compare import compare as compare
+
if sys.version_info < (3, 9):
def files(package_name: str):
diff --git a/bioimageio/core/utils/_compare.py b/bioimageio/core/utils/_compare.py
new file mode 100644
index 00000000..b8c673a9
--- /dev/null
+++ b/bioimageio/core/utils/_compare.py
@@ -0,0 +1,30 @@
+from difflib import HtmlDiff, unified_diff
+from typing import Sequence
+
+from typing_extensions import Literal, assert_never
+
+
+def compare(
+ a: Sequence[str],
+ b: Sequence[str],
+ name_a: str = "source",
+ name_b: str = "updated",
+ *,
+ diff_format: Literal["unified", "html"],
+):
+ if diff_format == "html":
+ diff = HtmlDiff().make_file(a, b, name_a, name_b, charset="utf-8")
+ elif diff_format == "unified":
+ diff = "\n".join(
+ unified_diff(
+ a,
+ b,
+ name_a,
+ name_b,
+ lineterm="",
+ )
+ )
+ else:
+ assert_never(diff_format)
+
+ return diff
diff --git a/bioimageio/core/utils/testing.py b/bioimageio/core/utils/testing.py
deleted file mode 100644
index acd65d95..00000000
--- a/bioimageio/core/utils/testing.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# TODO: move to tests/
-from functools import wraps
-from typing import Any, Protocol, Type
-
-
-class test_func(Protocol):
- def __call__(*args: Any, **kwargs: Any): ...
-
-
-def skip_on(exception: Type[Exception], reason: str):
- """adapted from https://stackoverflow.com/a/63522579"""
- import pytest
-
- # Func below is the real decorator and will receive the test function as param
- def decorator_func(f: test_func):
- @wraps(f)
- def wrapper(*args: Any, **kwargs: Any):
- try:
- # Try to run the test
- return f(*args, **kwargs)
- except exception:
- # If exception of given type happens
- # just swallow it and raise pytest.Skip with given reason
- pytest.skip(reason)
-
- return wrapper
-
- return decorator_func
diff --git a/bioimageio/core/weight_converter/__init__.py b/bioimageio/core/weight_converter/__init__.py
deleted file mode 100644
index 5f1674c9..00000000
--- a/bioimageio/core/weight_converter/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""coming soon"""
diff --git a/bioimageio/core/weight_converter/keras/__init__.py b/bioimageio/core/weight_converter/keras/__init__.py
deleted file mode 100644
index 195b42b8..00000000
--- a/bioimageio/core/weight_converter/keras/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO: update keras weight converters
diff --git a/bioimageio/core/weight_converter/keras/_tensorflow.py b/bioimageio/core/weight_converter/keras/_tensorflow.py
deleted file mode 100644
index c901f458..00000000
--- a/bioimageio/core/weight_converter/keras/_tensorflow.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# type: ignore # TODO: type
-import os
-import shutil
-from pathlib import Path
-from typing import no_type_check
-from zipfile import ZipFile
-
-try:
- import tensorflow.saved_model
-except Exception:
- tensorflow = None
-
-from bioimageio.spec._internal.io_utils import download
-from bioimageio.spec.model.v0_5 import ModelDescr
-
-
-def _zip_model_bundle(model_bundle_folder: Path):
- zipped_model_bundle = model_bundle_folder.with_suffix(".zip")
-
- with ZipFile(zipped_model_bundle, "w") as zip_obj:
- for root, _, files in os.walk(model_bundle_folder):
- for filename in files:
- src = os.path.join(root, filename)
- zip_obj.write(src, os.path.relpath(src, model_bundle_folder))
-
- try:
- shutil.rmtree(model_bundle_folder)
- except Exception:
- print("TensorFlow bundled model was not removed after compression")
-
- return zipped_model_bundle
-
-
-# adapted from
-# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
-def _convert_tf1(
- keras_weight_path: Path,
- output_path: Path,
- input_name: str,
- output_name: str,
- zip_weights: bool,
-):
- try:
- # try to build the tf model with the keras import from tensorflow
- from bioimageio.core.weight_converter.keras._tensorflow import (
- keras, # type: ignore
- )
-
- except Exception:
- # if the above fails try to export with the standalone keras
- import keras
-
- @no_type_check
- def build_tf_model():
- keras_model = keras.models.load_model(keras_weight_path)
- assert tensorflow is not None
- builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path)
- signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
- inputs={input_name: keras_model.input},
- outputs={output_name: keras_model.output},
- )
-
- signature_def_map = {
- tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
- }
-
- builder.add_meta_graph_and_variables(
- keras.backend.get_session(),
- [tensorflow.saved_model.tag_constants.SERVING],
- signature_def_map=signature_def_map,
- )
- builder.save()
-
- build_tf_model()
-
- if zip_weights:
- output_path = _zip_model_bundle(output_path)
- print("TensorFlow model exported to", output_path)
-
- return 0
-
-
-def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool):
- try:
- # try to build the tf model with the keras import from tensorflow
- from bioimageio.core.weight_converter.keras._tensorflow import keras
- except Exception:
- # if the above fails try to export with the standalone keras
- import keras
-
- model = keras.models.load_model(keras_weight_path)
- keras.models.save_model(model, output_path)
-
- if zip_weights:
- output_path = _zip_model_bundle(output_path)
- print("TensorFlow model exported to", output_path)
-
- return 0
-
-
-def convert_weights_to_tensorflow_saved_model_bundle(
- model: ModelDescr, output_path: Path
-):
- """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'.
-
- Adapted from
- https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py
-
- Args:
- model: The bioimageio model description
- output_path: where to save the tensorflow weights. This path must not exist yet.
- """
- assert tensorflow is not None
- tf_major_ver = int(tensorflow.__version__.split(".")[0])
-
- if output_path.suffix == ".zip":
- output_path = output_path.with_suffix("")
- zip_weights = True
- else:
- zip_weights = False
-
- if output_path.exists():
- raise ValueError(f"The ouptut directory at {output_path} must not exist.")
-
- if model.weights.keras_hdf5 is None:
- raise ValueError("Missing Keras Hdf5 weights to convert from.")
-
- weight_spec = model.weights.keras_hdf5
- weight_path = download(weight_spec.source).path
-
- if weight_spec.tensorflow_version:
- model_tf_major_ver = int(weight_spec.tensorflow_version.major)
- if model_tf_major_ver != tf_major_ver:
- raise RuntimeError(
- f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
- )
-
- if tf_major_ver == 1:
- if len(model.inputs) != 1 or len(model.outputs) != 1:
- raise NotImplementedError(
- "Weight conversion for models with multiple inputs or outputs is not yet implemented."
- )
- return _convert_tf1(
- weight_path,
- output_path,
- model.inputs[0].id,
- model.outputs[0].id,
- zip_weights,
- )
- else:
- return _convert_tf2(weight_path, output_path, zip_weights)
diff --git a/bioimageio/core/weight_converter/torch/__init__.py b/bioimageio/core/weight_converter/torch/__init__.py
deleted file mode 100644
index 1b1ba526..00000000
--- a/bioimageio/core/weight_converter/torch/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO: torch weight converters
diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py
deleted file mode 100644
index 3935e1d1..00000000
--- a/bioimageio/core/weight_converter/torch/_onnx.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# type: ignore # TODO: type
-import warnings
-from pathlib import Path
-from typing import Any, List, Sequence, cast
-
-import numpy as np
-from numpy.testing import assert_array_almost_equal
-
-from bioimageio.spec import load_description
-from bioimageio.spec.common import InvalidDescr
-from bioimageio.spec.model import v0_4, v0_5
-
-from ...digest_spec import get_member_id, get_test_inputs
-from ...weight_converter.torch._utils import load_torch_model
-
-try:
- import torch
-except ImportError:
- torch = None
-
-
-def add_onnx_weights(
- model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
- *,
- output_path: Path,
- use_tracing: bool = True,
- test_decimal: int = 4,
- verbose: bool = False,
- opset_version: "int | None" = None,
-):
- """Convert model weights from format 'pytorch_state_dict' to 'onnx'.
-
- Args:
- source_model: model without onnx weights
- opset_version: onnx opset version
- use_tracing: whether to use tracing or scripting to export the onnx format
- test_decimal: precision for testing whether the results agree
- """
- if isinstance(model_spec, (str, Path)):
- loaded_spec = load_description(Path(model_spec))
- if isinstance(loaded_spec, InvalidDescr):
- raise ValueError(f"Bad resource description: {loaded_spec}")
- if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
- raise TypeError(
- f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr"
- )
- model_spec = loaded_spec
-
- state_dict_weights_descr = model_spec.weights.pytorch_state_dict
- if state_dict_weights_descr is None:
- raise ValueError(
- "The provided model does not have weights in the pytorch state dict format"
- )
-
- assert torch is not None
- with torch.no_grad():
-
- sample = get_test_inputs(model_spec)
- input_data = [sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs]
- input_tensors = [torch.from_numpy(ipt) for ipt in input_data]
- model = load_torch_model(state_dict_weights_descr)
-
- expected_tensors = model(*input_tensors)
- if isinstance(expected_tensors, torch.Tensor):
- expected_tensors = [expected_tensors]
- expected_outputs: List[np.ndarray[Any, Any]] = [
- out.numpy() for out in expected_tensors
- ]
-
- if use_tracing:
- torch.onnx.export(
- model,
- tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0],
- str(output_path),
- verbose=verbose,
- opset_version=opset_version,
- )
- else:
- raise NotImplementedError
-
- try:
- import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs]
- except ImportError:
- msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
- warnings.warn(msg)
- return
-
- # check the onnx model
- sess = rt.InferenceSession(str(output_path))
- onnx_input_node_args = cast(
- List[Any], sess.get_inputs()
- ) # fixme: remove cast, try using rt.NodeArg instead of Any
- onnx_inputs = {
- input_name.name: inp
- for input_name, inp in zip(onnx_input_node_args, input_data)
- }
- outputs = cast(
- Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs)
- ) # FIXME: remove cast
-
- try:
- for exp, out in zip(expected_outputs, outputs):
- assert_array_almost_equal(exp, out, decimal=test_decimal)
- return 0
- except AssertionError as e:
- msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}"
- warnings.warn(msg)
- return 1
diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py
deleted file mode 100644
index 5ca16069..00000000
--- a/bioimageio/core/weight_converter/torch/_torchscript.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# type: ignore # TODO: type
-from pathlib import Path
-from typing import List, Sequence, Union
-
-import numpy as np
-from numpy.testing import assert_array_almost_equal
-from typing_extensions import Any, assert_never
-
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.model.v0_5 import Version
-
-from ._utils import load_torch_model
-
-try:
- import torch
-except ImportError:
- torch = None
-
-
-# FIXME: remove Any
-def _check_predictions(
- model: Any,
- scripted_model: Any,
- model_spec: "v0_4.ModelDescr | v0_5.ModelDescr",
- input_data: Sequence["torch.Tensor"],
-):
- assert torch is not None
-
- def _check(input_: Sequence[torch.Tensor]) -> None:
- expected_tensors = model(*input_)
- if isinstance(expected_tensors, torch.Tensor):
- expected_tensors = [expected_tensors]
- expected_outputs: List[np.ndarray[Any, Any]] = [
- out.numpy() for out in expected_tensors
- ]
-
- output_tensors = scripted_model(*input_)
- if isinstance(output_tensors, torch.Tensor):
- output_tensors = [output_tensors]
- outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors]
-
- try:
- for exp, out in zip(expected_outputs, outputs):
- assert_array_almost_equal(exp, out, decimal=4)
- except AssertionError as e:
- raise ValueError(
- f"Results before and after weights conversion do not agree:\n {str(e)}"
- )
-
- _check(input_data)
-
- if len(model_spec.inputs) > 1:
- return # FIXME: why don't we check multiple inputs?
-
- input_descr = model_spec.inputs[0]
- if isinstance(input_descr, v0_4.InputTensorDescr):
- if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape):
- return
- min_shape = input_descr.shape.min
- step = input_descr.shape.step
- else:
- min_shape: List[int] = []
- step: List[int] = []
- for axis in input_descr.axes:
- if isinstance(axis.size, v0_5.ParameterizedSize):
- min_shape.append(axis.size.min)
- step.append(axis.size.step)
- elif isinstance(axis.size, int):
- min_shape.append(axis.size)
- step.append(0)
- elif axis.size is None:
- raise NotImplementedError(
- f"Can't verify inputs that don't specify their shape fully: {axis}"
- )
- elif isinstance(axis.size, v0_5.SizeReference):
- raise NotImplementedError(f"Can't handle axes like '{axis}' yet")
- else:
- assert_never(axis.size)
-
- half_step = [st // 2 for st in step]
- max_steps = 4
-
- # check that input and output agree for decreasing input sizes
- for step_factor in range(1, max_steps + 1):
- slice_ = tuple(
- slice(None) if st == 0 else slice(step_factor * st, -step_factor * st)
- for st in half_step
- )
- this_input = [inp[slice_] for inp in input_data]
- this_shape = this_input[0].shape
- if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)):
- raise ValueError(
- f"Mismatched shapes: {this_shape}. Expected at least {min_shape}"
- )
- _check(this_input)
-
-
-def convert_weights_to_torchscript(
- model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr],
- output_path: Path,
- use_tracing: bool = True,
-) -> v0_5.TorchscriptWeightsDescr:
- """Convert model weights from format 'pytorch_state_dict' to 'torchscript'.
-
- Args:
- model_descr: location of the resource for the input bioimageio model
- output_path: where to save the torchscript weights
- use_tracing: whether to use tracing or scripting to export the torchscript format
- """
-
- state_dict_weights_descr = model_descr.weights.pytorch_state_dict
- if state_dict_weights_descr is None:
- raise ValueError(
- "The provided model does not have weights in the pytorch state dict format"
- )
-
- input_data = model_descr.get_input_test_arrays()
-
- with torch.no_grad():
- input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data]
-
- model = load_torch_model(state_dict_weights_descr)
-
- # FIXME: remove Any
- if use_tracing:
- scripted_model: Any = torch.jit.trace(model, input_data)
- else:
- scripted_model: Any = torch.jit.script(model)
-
- _check_predictions(
- model=model,
- scripted_model=scripted_model,
- model_spec=model_descr,
- input_data=input_data,
- )
-
- # save the torchscript model
- scripted_model.save(
- str(output_path)
- ) # does not support Path, so need to cast to str
-
- return v0_5.TorchscriptWeightsDescr(
- source=output_path,
- pytorch_version=Version(torch.__version__),
- parent="pytorch_state_dict",
- )
diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py
deleted file mode 100644
index 01df0747..00000000
--- a/bioimageio/core/weight_converter/torch/_utils.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import Union
-
-from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.utils import download
-
-try:
- import torch
-except ImportError:
- torch = None
-
-
-# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
-# and for each weight format
-def load_torch_model( # pyright: ignore[reportUnknownParameterType]
- node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr],
-):
- assert torch is not None
- model = ( # pyright: ignore[reportUnknownVariableType]
- PytorchModelAdapter.get_network(node)
- )
- state = torch.load(download(node.source).path, map_location="cpu")
- model.load_state_dict(state) # FIXME: check incompatible keys?
- return model.eval() # pyright: ignore[reportUnknownVariableType]
diff --git a/bioimageio/core/weight_converters/__init__.py b/bioimageio/core/weight_converters/__init__.py
new file mode 100644
index 00000000..31a91642
--- /dev/null
+++ b/bioimageio/core/weight_converters/__init__.py
@@ -0,0 +1,3 @@
+from ._add_weights import add_weights
+
+__all__ = ["add_weights"]
diff --git a/bioimageio/core/weight_converters/_add_weights.py b/bioimageio/core/weight_converters/_add_weights.py
new file mode 100644
index 00000000..978c8450
--- /dev/null
+++ b/bioimageio/core/weight_converters/_add_weights.py
@@ -0,0 +1,173 @@
+import traceback
+from typing import Optional
+
+from loguru import logger
+from pydantic import DirectoryPath
+
+from bioimageio.spec import (
+ InvalidDescr,
+ load_model_description,
+ save_bioimageio_package_as_folder,
+)
+from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
+
+from .._resource_tests import load_description_and_test
+
+
+def add_weights(
+ model_descr: ModelDescr,
+ *,
+ output_path: DirectoryPath,
+ source_format: Optional[WeightsFormat] = None,
+ target_format: Optional[WeightsFormat] = None,
+ verbose: bool = False,
+) -> Optional[ModelDescr]:
+ """Convert model weights to other formats and add them to the model description
+
+ Args:
+ output_path: Path to save updated model package to.
+ source_format: convert from a specific weights format.
+ Default: choose automatically from any available.
+ target_format: convert to a specific weights format.
+ Default: attempt to convert to any missing format.
+ devices: Devices that may be used during conversion.
+ verbose: log more (error) output
+
+ Returns:
+ - An updated model description if any converted weights were added.
+ - `None` if no conversion was possible.
+ """
+ if not isinstance(model_descr, ModelDescr):
+ if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
+ raise TypeError(
+ f"Model format {model_descr.format} is not supported, please update"
+ + f" model to format {ModelDescr.implemented_format_version} first."
+ )
+
+ raise TypeError(type(model_descr))
+
+ # save model to local folder
+ output_path = save_bioimageio_package_as_folder(
+ model_descr, output_path=output_path
+ )
+ # reload from local folder to make sure we do not edit the given model
+ _model_descr = load_model_description(output_path, perform_io_checks=False)
+ assert isinstance(_model_descr, ModelDescr)
+ model_descr = _model_descr
+ del _model_descr
+
+ if source_format is None:
+ available = set(model_descr.weights.available_formats)
+ else:
+ available = {source_format}
+
+ if target_format is None:
+ missing = set(model_descr.weights.missing_formats)
+ else:
+ missing = {target_format}
+
+ originally_missing = set(missing)
+
+ if "pytorch_state_dict" in available and "torchscript" in missing:
+ logger.info(
+ "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
+ )
+ from .pytorch_to_torchscript import convert
+
+ try:
+ torchscript_weights_path = output_path / "weights_torchscript.pt"
+ model_descr.weights.torchscript = convert(
+ model_descr,
+ output_path=torchscript_weights_path,
+ use_tracing=False,
+ )
+ except Exception as e:
+ if verbose:
+ traceback.print_exception(e)
+
+ logger.error(e)
+ else:
+ available.add("torchscript")
+ missing.discard("torchscript")
+
+ if "pytorch_state_dict" in available and "torchscript" in missing:
+ logger.info(
+ "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
+ )
+ from .pytorch_to_torchscript import convert
+
+ try:
+ torchscript_weights_path = output_path / "weights_torchscript_traced.pt"
+
+ model_descr.weights.torchscript = convert(
+ model_descr,
+ output_path=torchscript_weights_path,
+ use_tracing=True,
+ )
+ except Exception as e:
+ if verbose:
+ traceback.print_exception(e)
+
+ logger.error(e)
+ else:
+ available.add("torchscript")
+ missing.discard("torchscript")
+
+ if "torchscript" in available and "onnx" in missing:
+ logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
+ from .torchscript_to_onnx import convert
+
+ try:
+ onnx_weights_path = output_path / "weights.onnx"
+ model_descr.weights.onnx = convert(
+ model_descr,
+ output_path=onnx_weights_path,
+ )
+ except Exception as e:
+ if verbose:
+ traceback.print_exception(e)
+
+ logger.error(e)
+ else:
+ available.add("onnx")
+ missing.discard("onnx")
+
+ if "pytorch_state_dict" in available and "onnx" in missing:
+ logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
+ from .pytorch_to_onnx import convert
+
+ try:
+ onnx_weights_path = output_path / "weights.onnx"
+
+ model_descr.weights.onnx = convert(
+ model_descr,
+ output_path=onnx_weights_path,
+ verbose=verbose,
+ )
+ except Exception as e:
+ if verbose:
+ traceback.print_exception(e)
+
+ logger.error(e)
+ else:
+ available.add("onnx")
+ missing.discard("onnx")
+
+ if missing:
+ logger.warning(
+ f"Converting from any of the available weights formats {available} to any"
+ + f" of {missing} failed or is not yet implemented. Please create an issue"
+ + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
+ + " if you would like bioimageio.core to support a particular conversion."
+ )
+
+ if originally_missing == missing:
+ logger.warning("failed to add any converted weights")
+ return None
+ else:
+ logger.info("added weights formats {}", originally_missing - missing)
+ # resave model with updated rdf.yaml
+ _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
+ tested_model_descr = load_description_and_test(model_descr)
+ assert isinstance(tested_model_descr, ModelDescr)
+ return tested_model_descr
diff --git a/bioimageio/core/weight_converters/_utils_onnx.py b/bioimageio/core/weight_converters/_utils_onnx.py
new file mode 100644
index 00000000..3c45d245
--- /dev/null
+++ b/bioimageio/core/weight_converters/_utils_onnx.py
@@ -0,0 +1,15 @@
+from collections import defaultdict
+from itertools import chain
+from typing import DefaultDict, Dict
+
+from bioimageio.spec.model.v0_5 import ModelDescr
+
+
+def get_dynamic_axes(model_descr: ModelDescr):
+ dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
+ for d in chain(model_descr.inputs, model_descr.outputs):
+ for i, ax in enumerate(d.axes):
+ if not isinstance(ax.size, int):
+ dynamic_axes[str(d.id)][i] = str(ax.id)
+
+ return dynamic_axes
diff --git a/bioimageio/core/weight_converters/keras_to_tensorflow.py b/bioimageio/core/weight_converters/keras_to_tensorflow.py
new file mode 100644
index 00000000..ac8886e1
--- /dev/null
+++ b/bioimageio/core/weight_converters/keras_to_tensorflow.py
@@ -0,0 +1,183 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Union, no_type_check
+from zipfile import ZipFile
+
+import tensorflow # pyright: ignore[reportMissingTypeStubs]
+
+from bioimageio.spec._internal.io import download
+from bioimageio.spec._internal.version_type import Version
+from bioimageio.spec.common import ZipPath
+from bioimageio.spec.model.v0_5 import (
+ InputTensorDescr,
+ ModelDescr,
+ OutputTensorDescr,
+ TensorflowSavedModelBundleWeightsDescr,
+)
+
+from .. import __version__
+from ..io import ensure_unzipped
+
+try:
+ # try to build the tf model with the keras import from tensorflow
+ from tensorflow import keras # type: ignore
+except Exception:
+ # if the above fails try to export with the standalone keras
+ import keras # pyright: ignore[reportMissingTypeStubs]
+
+
+def convert(
+ model_descr: ModelDescr, output_path: Path
+) -> TensorflowSavedModelBundleWeightsDescr:
+ """
+ Convert model weights from the 'keras_hdf5' format to the 'tensorflow_saved_model_bundle' format.
+
+ This method handles the conversion of Keras HDF5 model weights into a TensorFlow SavedModel bundle,
+ which is the recommended format for deploying TensorFlow models. The method supports both TensorFlow 1.x
+ and 2.x versions, with appropriate checks to ensure compatibility.
+
+ Adapted from:
+ https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py
+
+ Args:
+ model_descr:
+ The bioimage.io model description containing the model's metadata and weights.
+ output_path:
+ Path with .zip suffix (.zip is appended otherwise) to which a zip archive
+ with the TensorFlow SavedModel bundle will be saved.
+ Raises:
+ ValueError:
+ - If the specified `output_path` already exists.
+ - If the Keras HDF5 weights are missing in the model description.
+ RuntimeError:
+ If there is a mismatch between the TensorFlow version used by the model and the version installed.
+ NotImplementedError:
+ If the model has multiple inputs or outputs and TensorFlow 1.x is being used.
+
+ Returns:
+ A descriptor object containing information about the converted TensorFlow SavedModel bundle.
+ """
+ tf_major_ver = int(tensorflow.__version__.split(".")[0])
+
+ if output_path.suffix != ".zip":
+ output_path = output_path.with_suffix("")
+
+ if output_path.exists():
+ raise ValueError(f"The output directory at {output_path} must not exist.")
+
+ if model_descr.weights.keras_hdf5 is None:
+ raise ValueError("Missing Keras Hdf5 weights to convert from.")
+
+ weight_spec = model_descr.weights.keras_hdf5
+ weight_path = download(weight_spec.source).path
+
+ if weight_spec.tensorflow_version:
+ model_tf_major_ver = int(weight_spec.tensorflow_version.major)
+ if model_tf_major_ver != tf_major_ver:
+ raise RuntimeError(
+ f"Tensorflow major version of the model ({model_tf_major_ver}) does not match the installed version ({tf_major_ver})"
+ )
+
+ if tf_major_ver == 1:
+ if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1:
+ raise NotImplementedError(
+ "Weight conversion for models with multiple inputs or outputs is not yet implemented."
+ )
+
+ input_name = str(
+ d.id
+ if isinstance((d := model_descr.inputs[0]), InputTensorDescr)
+ else d.name
+ )
+ output_name = str(
+ d.id
+ if isinstance((d := model_descr.outputs[0]), OutputTensorDescr)
+ else d.name
+ )
+ return _convert_tf1(
+ ensure_unzipped(weight_path, Path("bioimageio_unzipped_tf_weights")),
+ output_path,
+ input_name,
+ output_name,
+ )
+ else:
+ return _convert_tf2(weight_path, output_path)
+
+
+def _convert_tf2(
+ keras_weight_path: Union[Path, ZipPath], output_path: Path
+) -> TensorflowSavedModelBundleWeightsDescr:
+ model = keras.models.load_model(keras_weight_path) # type: ignore
+ model.export(output_path) # type: ignore
+
+ output_path = _zip_model_bundle(output_path)
+ print("TensorFlow model exported to", output_path)
+
+ return TensorflowSavedModelBundleWeightsDescr(
+ source=output_path,
+ parent="keras_hdf5",
+ tensorflow_version=Version(tensorflow.__version__),
+ comment=f"Converted with bioimageio.core {__version__}.",
+ )
+
+
+# adapted from
+# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
+def _convert_tf1(
+ keras_weight_path: Path,
+ output_path: Path,
+ input_name: str,
+ output_name: str,
+) -> TensorflowSavedModelBundleWeightsDescr:
+
+ @no_type_check
+ def build_tf_model():
+ keras_model = keras.models.load_model(keras_weight_path)
+ builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path)
+ signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
+ inputs={input_name: keras_model.input},
+ outputs={output_name: keras_model.output},
+ )
+
+ signature_def_map = {
+ tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: (
+ signature
+ )
+ }
+
+ builder.add_meta_graph_and_variables(
+ keras.backend.get_session(),
+ [tensorflow.saved_model.tag_constants.SERVING],
+ signature_def_map=signature_def_map,
+ )
+ builder.save()
+
+ build_tf_model()
+
+ output_path = _zip_model_bundle(output_path)
+ print("TensorFlow model exported to", output_path)
+
+ return TensorflowSavedModelBundleWeightsDescr(
+ source=output_path,
+ parent="keras_hdf5",
+ tensorflow_version=Version(tensorflow.__version__),
+ comment=f"Converted with bioimageio.core {__version__}.",
+ )
+
+
+def _zip_model_bundle(model_bundle_folder: Path):
+ zipped_model_bundle = model_bundle_folder.with_suffix(".zip")
+
+ with ZipFile(zipped_model_bundle, "w") as zip_obj:
+ for root, _, files in os.walk(model_bundle_folder):
+ for filename in files:
+ src = os.path.join(root, filename)
+ zip_obj.write(src, os.path.relpath(src, model_bundle_folder))
+
+ try:
+ shutil.rmtree(model_bundle_folder)
+ except Exception:
+ print("TensorFlow bundled model was not removed after compression")
+
+ return zipped_model_bundle
diff --git a/bioimageio/core/weight_converters/pytorch_to_onnx.py b/bioimageio/core/weight_converters/pytorch_to_onnx.py
new file mode 100644
index 00000000..72d819b1
--- /dev/null
+++ b/bioimageio/core/weight_converters/pytorch_to_onnx.py
@@ -0,0 +1,79 @@
+from pathlib import Path
+
+import torch
+
+from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
+
+from .. import __version__
+from ..backends.pytorch_backend import load_torch_model
+from ..digest_spec import get_member_id, get_test_inputs
+from ..proc_setup import get_pre_and_postprocessing
+from ._utils_onnx import get_dynamic_axes
+
+
+def convert(
+ model_descr: ModelDescr,
+ output_path: Path,
+ *,
+ verbose: bool = False,
+ opset_version: int = 15,
+) -> OnnxWeightsDescr:
+ """
+ Convert model weights from the PyTorch state_dict format to the ONNX format.
+
+ Args:
+ model_descr:
+ The model description object that contains the model and its weights.
+ output_path:
+ The file path where the ONNX model will be saved.
+ verbose:
+ If True, will print out detailed information during the ONNX export process. Defaults to False.
+ opset_version:
+ The ONNX opset version to use for the export. Defaults to 15.
+
+ Raises:
+ ValueError:
+ If the provided model does not have weights in the PyTorch state_dict format.
+
+ Returns:
+ A descriptor object that contains information about the exported ONNX weights.
+ """
+
+ state_dict_weights_descr = model_descr.weights.pytorch_state_dict
+ if state_dict_weights_descr is None:
+ raise ValueError(
+ "The provided model does not have weights in the pytorch state dict format"
+ )
+
+ sample = get_test_inputs(model_descr)
+ procs = get_pre_and_postprocessing(
+ model_descr, dataset_for_initial_statistics=[sample]
+ )
+ procs.pre(sample)
+ inputs_numpy = [
+ sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
+ ]
+ inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
+ model = load_torch_model(state_dict_weights_descr, load_state=True)
+ with torch.no_grad():
+ outputs_original_torch = model(*inputs_torch)
+ if isinstance(outputs_original_torch, torch.Tensor):
+ outputs_original_torch = [outputs_original_torch]
+
+ _ = torch.onnx.export(
+ model,
+ tuple(inputs_torch),
+ str(output_path),
+ input_names=[str(d.id) for d in model_descr.inputs],
+ output_names=[str(d.id) for d in model_descr.outputs],
+ dynamic_axes=get_dynamic_axes(model_descr),
+ verbose=verbose,
+ opset_version=opset_version,
+ )
+
+ return OnnxWeightsDescr(
+ source=output_path,
+ parent="pytorch_state_dict",
+ opset_version=opset_version,
+ comment=f"Converted with bioimageio.core {__version__}.",
+ )
diff --git a/bioimageio/core/weight_converters/pytorch_to_torchscript.py b/bioimageio/core/weight_converters/pytorch_to_torchscript.py
new file mode 100644
index 00000000..3d0f281c
--- /dev/null
+++ b/bioimageio/core/weight_converters/pytorch_to_torchscript.py
@@ -0,0 +1,70 @@
+from pathlib import Path
+from typing import Any, Tuple, Union
+
+import torch
+from torch.jit import ScriptModule
+
+from bioimageio.spec._internal.version_type import Version
+from bioimageio.spec.model.v0_5 import ModelDescr, TorchscriptWeightsDescr
+
+from .. import __version__
+from ..backends.pytorch_backend import load_torch_model
+
+
+def convert(
+ model_descr: ModelDescr,
+ output_path: Path,
+ *,
+ use_tracing: bool = True,
+) -> TorchscriptWeightsDescr:
+ """
+ Convert model weights from the PyTorch `state_dict` format to TorchScript.
+
+ Args:
+ model_descr:
+ The model description object that contains the model and its weights in the PyTorch `state_dict` format.
+ output_path:
+ The file path where the TorchScript model will be saved.
+ use_tracing:
+ Whether to use tracing or scripting to export the TorchScript format.
+ - `True`: Use tracing, which is recommended for models with straightforward control flow.
+ - `False`: Use scripting, which is better for models with dynamic control flow (e.g., loops, conditionals).
+
+ Raises:
+ ValueError:
+ If the provided model does not have weights in the PyTorch `state_dict` format.
+
+ Returns:
+ A descriptor object that contains information about the exported TorchScript weights.
+ """
+ state_dict_weights_descr = model_descr.weights.pytorch_state_dict
+ if state_dict_weights_descr is None:
+ raise ValueError(
+ "The provided model does not have weights in the pytorch state dict format"
+ )
+
+ input_data = model_descr.get_input_test_arrays()
+
+ with torch.no_grad():
+ input_data = [torch.from_numpy(inp) for inp in input_data]
+ model = load_torch_model(state_dict_weights_descr, load_state=True)
+ scripted_model: Union[ # pyright: ignore[reportUnknownVariableType]
+ ScriptModule, Tuple[Any, ...]
+ ] = (
+ torch.jit.trace(model, input_data)
+ if use_tracing
+ else torch.jit.script(model)
+ )
+ assert not isinstance(scripted_model, tuple), scripted_model
+
+ scripted_model.save(output_path)
+
+ return TorchscriptWeightsDescr(
+ source=output_path,
+ pytorch_version=Version(torch.__version__),
+ parent="pytorch_state_dict",
+ comment=(
+ f"Converted with bioimageio.core {__version__}"
+ + f" with use_tracing={use_tracing}."
+ ),
+ )
diff --git a/bioimageio/core/weight_converters/torchscript_to_onnx.py b/bioimageio/core/weight_converters/torchscript_to_onnx.py
new file mode 100644
index 00000000..d58b47ab
--- /dev/null
+++ b/bioimageio/core/weight_converters/torchscript_to_onnx.py
@@ -0,0 +1,84 @@
+from pathlib import Path
+
+import torch.jit
+
+from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
+from bioimageio.spec.utils import download
+
+from .. import __version__
+from ..digest_spec import get_member_id, get_test_inputs
+from ..proc_setup import get_pre_and_postprocessing
+from ._utils_onnx import get_dynamic_axes
+
+
+def convert(
+ model_descr: ModelDescr,
+ output_path: Path,
+ *,
+ verbose: bool = False,
+ opset_version: int = 15,
+) -> OnnxWeightsDescr:
+ """
+ Convert model weights from the PyTorch state_dict format to the ONNX format.
+
+ Args:
+ model_descr:
+ The model description object that contains the model and its weights.
+ output_path (Path):
+ The file path where the ONNX model will be saved.
+ verbose (bool, optional):
+ If True, will print out detailed information during the ONNX export process. Defaults to False.
+ opset_version (int, optional):
+ The ONNX opset version to use for the export. Defaults to 15.
+ Raises:
+ ValueError:
+ If the provided model does not have weights in the torchscript format.
+
+ Returns:
+ v0_5.OnnxWeightsDescr:
+ A descriptor object that contains information about the exported ONNX weights.
+ """
+
+ torchscript_descr = model_descr.weights.torchscript
+ if torchscript_descr is None:
+ raise ValueError(
+ "The provided model does not have weights in the torchscript format"
+ )
+
+ sample = get_test_inputs(model_descr)
+ procs = get_pre_and_postprocessing(
+ model_descr, dataset_for_initial_statistics=[sample]
+ )
+ procs.pre(sample)
+ inputs_numpy = [
+ sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
+ ]
+ inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
+
+ weight_path = download(torchscript_descr).path
+ model = torch.jit.load(weight_path) # type: ignore
+ model.to("cpu")
+ model = model.eval() # type: ignore
+
+ with torch.no_grad():
+ outputs_original_torch = model(*inputs_torch) # type: ignore
+ if isinstance(outputs_original_torch, torch.Tensor):
+ outputs_original_torch = [outputs_original_torch]
+
+ _ = torch.onnx.export(
+ model, # type: ignore
+ tuple(inputs_torch),
+ str(output_path),
+ input_names=[str(d.id) for d in model_descr.inputs],
+ output_names=[str(d.id) for d in model_descr.outputs],
+ dynamic_axes=get_dynamic_axes(model_descr),
+ verbose=verbose,
+ opset_version=opset_version,
+ )
+
+ return OnnxWeightsDescr(
+ source=output_path,
+ parent="torchscript",
+ opset_version=opset_version,
+ comment=f"Converted with bioimageio.core {__version__}.",
+ )
diff --git a/dev/env-wo-python.yaml b/dev/env-dev.yaml
similarity index 59%
rename from dev/env-wo-python.yaml
rename to dev/env-dev.yaml
index d8cba289..13378376 100644
--- a/dev/env-wo-python.yaml
+++ b/dev/env-dev.yaml
@@ -1,21 +1,23 @@
-# modified copy of env.yaml
+# modified copy of env-full.yaml wo dependencies 'for model testing'
name: core
channels:
- conda-forge
- nodefaults
- - pytorch # added
+ - pytorch
dependencies:
- - bioimageio.spec>=0.5.3.5
+ - bioimageio.spec==0.5.4.1
- black
# - crick # currently requires python<=3.9
- - filelock
- h5py
+ - imagecodecs
- imageio>=2.5
- jupyter
- jupyter-black
- - keras>=3.0
+ - keras>=3.0,<4
- loguru
+ - matplotlib
- numpy
+ - onnx
- onnxruntime
- packaging>=17.0
- pdoc
@@ -27,16 +29,16 @@ dependencies:
- pyright
- pytest
- pytest-cov
- - pytest-xdist
- # - python=3.9 # removed
- - pytorch>=2.1
+ # - python=3.11 # removed
+ - pytorch>=2.1,<3
- requests
- rich
- ruff
- ruyaml
+ - tensorflow>=2,<3
- torchvision
- tqdm
- typing-extensions
- - xarray
+ - xarray>=2024.01,<2025.3.0
- pip:
- -e ..
diff --git a/dev/env-full.yaml b/dev/env-full.yaml
new file mode 100644
index 00000000..a9dc0132
--- /dev/null
+++ b/dev/env-full.yaml
@@ -0,0 +1,49 @@
+name: core-full
+channels:
+ - conda-forge
+ - nodefaults
+ - pytorch
+dependencies:
+ - bioimageio.spec==0.5.4.1
+ - black
+ # - careamics # TODO: add careamics for model testing (currently pins pydantic to <2.9)
+ - cellpose # for model testing
+ # - crick # currently requires python<=3.9
+ - h5py
+ - imagecodecs
+ - imageio>=2.5
+ - jupyter
+ - jupyter-black
+ - keras>=3.0,<4
+ - loguru
+ - matplotlib
+ - monai # for model testing
+ - numpy
+ - onnx
+ - onnxruntime
+ - packaging>=17.0
+ - pdoc
+ - pip
+ - pre-commit
+ - psutil
+ - pydantic
+ - pydantic-settings
+ - pyright
+ - pytest
+ - pytest-cov
+ - python=3.11 # 3.12 not supported by cellpose->fastremap
+ - pytorch>=2.1,<3
+ - requests
+ - rich
+ - ruff
+ - ruyaml
+ - segment-anything # for model testing
+ - tensorflow>=2,<3
+ - timm # for model testing
+ - torchvision>=0.21
+ - tqdm
+ - typing-extensions
+ - xarray>=2024.01,<2025.3.0
+ - pip:
+ - git+https://github.com/ChaoningZhang/MobileSAM.git # for model testing
+ - -e ..
diff --git a/dev/env-gpu.yaml b/dev/env-gpu.yaml
new file mode 100644
index 00000000..7fc2123c
--- /dev/null
+++ b/dev/env-gpu.yaml
@@ -0,0 +1,52 @@
+# version of env-full for running on GPU
+name: core-gpu
+channels:
+ - conda-forge
+ - nodefaults
+dependencies:
+ - bioimageio.spec==0.5.4.1
+ - black
+ - cellpose # for model testing
+ # - crick # currently requires python<=3.9
+ - h5py
+ - imagecodecs
+ - imageio>=2.5
+ - jupyter
+ - jupyter-black
+ - keras>=3.0,<4
+ - loguru
+ - matplotlib
+ - monai # for model testing
+ - numpy
+ - onnx
+ - packaging>=17.0
+ - pdoc
+ - pip
+ - pre-commit
+ - psutil
+ - pydantic<2.9
+ - pydantic-settings
+ - pyright
+ - pytest
+ - pytest-cov
+ - python=3.11
+ - requests
+ - rich
+ - ruff
+ - ruyaml
+ - segment-anything # for model testing
+ - timm # for model testing
+ - tqdm
+ - typing-extensions
+ - xarray>=2024.01,<2025.3.0
+ - pip:
+ # - tf2onnx # TODO: add tf2onnx
+ - --extra-index-url https://download.pytorch.org/whl/cu126
+ - careamics # TODO: add careamics for model testing (currently pins pydantic to <2.9)
+ - git+https://github.com/ChaoningZhang/MobileSAM.git # for model testing
+ - onnxruntime-gpu
+ - tensorflow
+ - torch
+ - torchaudio
+ - torchvision>=0.21
+ - -e ..
diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml
index 22353103..6fc6597a 100644
--- a/dev/env-py38.yaml
+++ b/dev/env-py38.yaml
@@ -1,20 +1,23 @@
-# manipulated copy of env.yaml
-name: core38
+# manipulated copy of env-full.yaml wo dependencies 'for model testing' for python 3.8
+name: core-py38
channels:
- conda-forge
- nodefaults
+ - pytorch
dependencies:
- - bioimageio.spec>=0.5.3.5
+ - bioimageio.spec==0.5.4.1
- black
- crick # uncommented
- - filelock
- h5py
+ - imagecodecs
- imageio>=2.5
- jupyter
- jupyter-black
- # - keras>=3.0 # removed
+ # - keras>=3.0,<4 # removed
- loguru
+ - matplotlib
- numpy
+ - onnx
- onnxruntime
- packaging>=17.0
- pdoc
@@ -26,16 +29,16 @@ dependencies:
- pyright
- pytest
- pytest-cov
- - pytest-xdist
- python=3.8 # changed
- - pytorch>=2.1
+ - pytorch>=2.1,<3
- requests
- rich
- ruff
- ruyaml
+ # - tensorflow>=2,<3 # removed
- torchvision
- tqdm
- typing-extensions
- - xarray
+ - xarray>=2023.01,<2025.3.0
- pip:
- -e ..
diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml
deleted file mode 100644
index 0df6fd07..00000000
--- a/dev/env-tf.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-# modified copy of env.yaml
-name: core-tf # changed
-channels:
- - conda-forge
- - nodefaults
-dependencies:
- - bioimageio.spec>=0.5.3.5
- - black
- # - crick # currently requires python<=3.9
- - filelock
- - h5py
- - imageio>=2.5
- - jupyter
- - jupyter-black
- - keras>=2.15 # changed
- - loguru
- - numpy
- - onnxruntime
- - packaging>=17.0
- - pdoc
- - pip
- - pre-commit
- - psutil
- - pydantic
- - pydantic-settings
- - pyright
- - pytest
- - pytest-cov
- - pytest-xdist
- # - python=3.9 # removed
- # - pytorch>=2.1 # removed
- - requests
- - rich
- # - ruff # removed
- - ruyaml
- - tensorflow>=2.15 # added
- # - torchvision # removed
- - tqdm
- - typing-extensions
- - xarray
- - pip:
- - -e ..
diff --git a/dev/env.yaml b/dev/env.yaml
deleted file mode 100644
index 20d60a18..00000000
--- a/dev/env.yaml
+++ /dev/null
@@ -1,41 +0,0 @@
-name: core
-channels:
- - conda-forge
-dependencies:
- - bioimageio.spec>=0.5.3.5
- - black
- # - crick # currently requires python<=3.9
- - filelock
- - h5py
- - imageio>=2.5
- - jupyter
- - jupyter-black
- - ipykernel
- - matplotlib
- - keras>=3.0
- - loguru
- - numpy
- - onnxruntime
- - packaging>=17.0
- - pdoc
- - pip
- - pre-commit
- - psutil
- - pydantic
- - pydantic-settings
- - pyright
- - pytest
- - pytest-cov
- - pytest-xdist
- - python=3.9
- - pytorch>=2.1
- - requests
- - rich
- - ruff
- - ruyaml
- - torchvision
- - tqdm
- - typing-extensions
- - xarray
- - pip:
- - -e ..
diff --git a/presentations/create_ambitious_sloth.ipynb b/presentations/create_ambitious_sloth.ipynb
index 171b30db..8cda7fec 100644
--- a/presentations/create_ambitious_sloth.ipynb
+++ b/presentations/create_ambitious_sloth.ipynb
@@ -465,7 +465,7 @@
}
],
"source": [
- "pytorch_weights = torch.load(root / \"weights.pt\", weights_only=False)\n",
+ "pytorch_weights = torch.load(root / \"weights.pt\", weights_only=True)\n",
"pprint([(k, tuple(v.shape)) for k, v in pytorch_weights.items()][:4] + [\"...\"])"
]
},
diff --git a/pyproject.toml b/pyproject.toml
index 91cd2cbc..5d58fe72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.black]
line-length = 88
-extend_exclude = "/presentations/"
+extend-exclude = "/presentations/"
target-version = ["py38", "py39", "py310", "py311", "py312"]
preview = true
@@ -8,6 +8,7 @@ preview = true
exclude = [
"**/__pycache__",
"**/node_modules",
+ "dogfood",
"presentations",
"scripts/pdoc/original.py",
"scripts/pdoc/patched.py",
@@ -39,7 +40,8 @@ typeCheckingMode = "strict"
useLibraryCodeForTypes = true
[tool.pytest.ini_options]
-addopts = "--cov=bioimageio --cov-report=xml -n auto --capture=no --doctest-modules --failed-first"
+addopts = "--cov bioimageio --cov-report xml --cov-append --capture no --doctest-modules --failed-first --ignore dogfood --ignore bioimageio/core/backends --ignore bioimageio/core/weight_converters"
+testpaths = ["bioimageio/core", "tests"]
[tool.ruff]
line-length = 88
diff --git a/scripts/show_diff.py b/scripts/show_diff.py
index 1b0163bb..3e273d79 100644
--- a/scripts/show_diff.py
+++ b/scripts/show_diff.py
@@ -2,14 +2,14 @@
from pathlib import Path
from tempfile import TemporaryDirectory
-import pooch
+import pooch # pyright: ignore[reportMissingTypeStubs]
from bioimageio.core import load_description, save_bioimageio_yaml_only
if __name__ == "__main__":
rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/v0_4_9.bioimageio.yaml"
- local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore
+ local_source = Path(pooch.retrieve(rdf_source, None))
model_as_is = load_description(rdf_source, format_version="discover")
model_latest = load_description(rdf_source, format_version="latest")
print(model_latest.validation_summary)
diff --git a/setup.py b/setup.py
index 2465ff7e..0755ff2d 100644
--- a/setup.py
+++ b/setup.py
@@ -30,8 +30,9 @@
],
packages=find_namespace_packages(exclude=["tests"]),
install_requires=[
- "bioimageio.spec ==0.5.3.5",
+ "bioimageio.spec ==0.5.4.1",
"h5py",
+ "imagecodecs",
"imageio>=2.10",
"loguru",
"numpy",
@@ -41,33 +42,37 @@
"ruyaml",
"tqdm",
"typing-extensions",
- "xarray<2025.3.0",
+ "xarray>=2023.01,<2025.3.0",
],
include_package_data=True,
extras_require={
- "pytorch": ["torch>=1.6", "torchvision", "keras>=3.0"],
- "tensorflow": ["tensorflow", "keras>=2.15"],
+ "pytorch": (
+ pytorch_deps := ["torch>=1.6,<3", "torchvision>=0.21", "keras>=3.0,<4"]
+ ),
+ "tensorflow": ["tensorflow", "keras>=2.15,<4"],
"onnx": ["onnxruntime"],
- "dev": [
- "black",
- # "crick", # currently requires python<=3.9
- "filelock",
- "jupyter",
- "jupyter-black",
- "matplotlib",
- "keras>=3.0",
- "onnxruntime",
- "packaging>=17.0",
- "pre-commit",
- "pdoc",
- "psutil", # parallel pytest with 'pytest -n auto'
- "pyright",
- "pytest-cov",
- "pytest-xdist", # parallel pytest
- "pytest",
- "torch>=1.6",
- "torchvision",
- ],
+ "dev": (
+ pytorch_deps
+ + [
+ "black",
+ "cellpose", # for model testing
+ "jupyter-black",
+ "jupyter",
+ "matplotlib",
+ "monai", # for model testing
+ "onnx",
+ "onnxruntime",
+ "packaging>=17.0",
+ "pdoc",
+ "pre-commit",
+ "pyright==1.1.396",
+ "pytest-cov",
+ "pytest",
+ "segment-anything", # for model testing
+ "timm", # for model testing
+ # "crick", # currently requires python<=3.9
+ ]
+ ),
},
project_urls={
"Bug Reports": "https://github.com/bioimage-io/core-bioimage-io-python/issues",
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/conftest.py b/tests/conftest.py
index 253ade2f..32880b05 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,20 +1,23 @@
from __future__ import annotations
import subprocess
-import warnings
from itertools import chain
from typing import Dict, List
from loguru import logger
from pytest import FixtureRequest, fixture
+from bioimageio.core import enable_determinism
from bioimageio.spec import __version__ as bioimageio_spec_version
+enable_determinism()
+
+
try:
import torch
torch_version = tuple(map(int, torch.__version__.split(".")[:2]))
- logger.warning(f"detected torch version {torch_version}.x")
+ logger.warning("detected torch version {}", torch.__version__)
except ImportError:
torch = None
torch_version = None
@@ -29,7 +32,7 @@
try:
import tensorflow # type: ignore
- tf_major_version = int(tensorflow.__version__.split(".")[0]) # type: ignore
+ tf_major_version = int(tensorflow.__version__.split(".")[0])
except ImportError:
tensorflow = None
tf_major_version = None
@@ -41,56 +44,50 @@
skip_tensorflow = tensorflow is None
-warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}")
+logger.warning("testing with bioimageio.spec {}", bioimageio_spec_version)
+
+EXAMPLE_DESCRIPTIONS = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/"
# TODO: use models from new collection on S3
MODEL_SOURCES: Dict[str, str] = {
- "hpa_densenet": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml"
- ),
+ "hpa_densenet": "polite-pig/1.1",
"stardist": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models"
- "/stardist_example_model/v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/stardist_example_model/v0_4.bioimageio.yaml"
),
"shape_change": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "upsample_test_model/v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/upsample_test_model/v0_4.bioimageio.yaml"
),
"stardist_wrong_shape": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "stardist_example_model/rdf_wrong_shape.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/stardist_example_model/rdf_wrong_shape.yaml"
),
"stardist_wrong_shape2": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "stardist_example_model/rdf_wrong_shape2_v0_4.yaml"
+ EXAMPLE_DESCRIPTIONS
+ + "models/stardist_example_model/rdf_wrong_shape2_v0_4.yaml"
),
"unet2d_diff_output_shape": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "unet2d_diff_output_shape/v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/unet2d_diff_output_shape/bioimageio.yaml"
),
"unet2d_expand_output_shape": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS
+ + "models/unet2d_nuclei_broad/expand_output_shape.bioimageio.yaml"
),
"unet2d_fixed_shape": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "unet2d_fixed_shape/v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/unet2d_fixed_shape/v0_4.bioimageio.yaml"
),
"unet2d_keras_tf2": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "unet2d_keras_tf2/v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/unet2d_keras_tf2/v0_4.bioimageio.yaml"
),
"unet2d_keras": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "unet2d_keras_tf/v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/unet2d_keras_tf/v0_4.bioimageio.yaml"
),
"unet2d_multi_tensor": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "unet2d_multi_tensor/v0_4.bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/unet2d_multi_tensor/bioimageio.yaml"
+ ),
+ "unet2d_nuclei_broad_model_old": (
+ EXAMPLE_DESCRIPTIONS + "models/unet2d_nuclei_broad/v0_4_9.bioimageio.yaml"
),
"unet2d_nuclei_broad_model": (
- "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
- "unet2d_nuclei_broad/bioimageio.yaml"
+ EXAMPLE_DESCRIPTIONS + "models/unet2d_nuclei_broad/bioimageio.yaml"
),
}
@@ -244,6 +241,14 @@ def unet2d_nuclei_broad_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]
+# written as model group to automatically skip on missing torch
+@fixture(
+ scope="session", params=[] if skip_torch else ["unet2d_nuclei_broad_model_old"]
+)
+def unet2d_nuclei_broad_model_old(request: FixtureRequest):
+ return MODEL_SOURCES[request.param]
+
+
# written as model group to automatically skip on missing torch
@fixture(scope="session", params=[] if skip_torch else ["unet2d_diff_output_shape"])
def unet2d_diff_output_shape(request: FixtureRequest):
diff --git a/tests/weight_converter/test_add_weights.py b/tests/test_add_weights.py
similarity index 100%
rename from tests/weight_converter/test_add_weights.py
rename to tests/test_add_weights.py
diff --git a/tests/test_any_model_fixture.py b/tests/test_any_model_fixture.py
index a4cc1bce..77225f18 100644
--- a/tests/test_any_model_fixture.py
+++ b/tests/test_any_model_fixture.py
@@ -3,4 +3,4 @@
def test_model(any_model: str):
summary = load_description_and_validate_format_only(any_model)
- assert summary.status == "passed", summary.format()
+ assert summary.status == "valid-format", summary.display()
diff --git a/tests/test_bioimageio_collection.py b/tests/test_bioimageio_collection.py
new file mode 100644
index 00000000..92f3dd5c
--- /dev/null
+++ b/tests/test_bioimageio_collection.py
@@ -0,0 +1,113 @@
+import os
+from typing import Any, Collection, Dict, Iterable, Mapping, Tuple
+
+import pytest
+import requests
+from pydantic import HttpUrl
+
+from bioimageio.spec import InvalidDescr
+from bioimageio.spec.common import Sha256
+from tests.utils import ParameterSet, expensive_test
+
+BASE_URL = "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/"
+
+
+def _get_latest_rdf_sources():
+ entries: Any = requests.get(BASE_URL + "all_versions.json").json()["entries"]
+ ret: Dict[str, Tuple[HttpUrl, Sha256]] = {}
+ for entry in entries:
+ version = entry["versions"][0]
+ ret[f"{entry['concept']}/{version['v']}"] = (
+ HttpUrl(version["source"]), # pyright: ignore[reportCallIssue]
+ Sha256(version["sha256"]),
+ )
+
+ return ret
+
+
+ALL_LATEST_RDF_SOURCES: Mapping[str, Tuple[HttpUrl, Sha256]] = _get_latest_rdf_sources()
+
+
+def yield_bioimageio_yaml_urls() -> Iterable[ParameterSet]:
+ for descr_url, sha in ALL_LATEST_RDF_SOURCES.values():
+ key = (
+ str(descr_url)
+ .replace(BASE_URL, "")
+ .replace("/files/rdf.yaml", "")
+ .replace("/files/bioimageio.yaml", "")
+ )
+ yield pytest.param(descr_url, sha, key, id=key)
+
+
+KNOWN_INVALID: Collection[str] = {
+ "affable-shark/1.1", # onnx weights expect fixed input shape
+ "affectionate-cow/0.1.0", # custom dependencies
+ "ambitious-sloth/1.2", # requires inferno
+ "committed-turkey/1.2", # error deserializing VarianceScaling
+ "creative-panda/1", # error deserializing Conv2D
+ "dazzling-spider/0.1.0", # requires careamics
+ "discreet-rooster/1", # error deserializing VarianceScaling
+ "discreete-rooster/1", # error deserializing VarianceScaling
+ "dynamic-t-rex/1", # needs update to 0.5 for scale_linear with axes processing
+ "easy-going-sauropod/1", # CPU implementation of Conv3D currently only supports the NHWC tensor format.
+ "efficient-chipmunk/1", # needs plantseg
+ "famous-fish/0.1.0", # list index out of range `fl[3]`
+ "greedy-whale/1", # batch size is actually limited to 1
+ "happy-elephant/0.1.0", # list index out of range `fl[3]`
+ "happy-honeybee/0.1.0", # requires biapy
+ "heroic-otter/0.1.0", # requires biapy
+ "humorous-crab/1", # batch size is actually limited to 1
+ "humorous-fox/0.1.0", # requires careamics
+ "humorous-owl/1", # error deserializing GlorotUniform
+ "idealistic-turtle/0.1.0", # requires biapy
+ "impartial-shark/1", # error deserializing VarianceScaling
+ "intelligent-lion/0.1.0", # requires biapy
+ "joyful-deer/1", # needs update to 0.5 for scale_linear with axes processing
+ "merry-water-buffalo/0.1.0", # requires biapy
+ "naked-microbe/1", # unknown layer Convolution2D
+ "noisy-ox/1", # batch size is actually limited to 1
+ "non-judgemental-eagle/1", # error deserializing GlorotUniform
+ "straightforward-crocodile/1", # needs update to 0.5 for scale_linear with axes processing
+    "stupendous-sheep/1.1",  # requires relative import of attachment
+ "stupendous-sheep/1.2",
+ "venomous-swan/0.1.0", # requires biapy
+ "wild-rhino/0.1.0", # requires careamics
+}
+
+
+@pytest.mark.parametrize("descr_url,sha,key", list(yield_bioimageio_yaml_urls()))
+def test_rdf_format_to_populate_cache(
+ descr_url: HttpUrl,
+ sha: Sha256,
+ key: str,
+):
+ """this test is redundant if `test_rdf` runs, but is used in the CI to populate the cache"""
+ if os.environ.get("BIOIMAGEIO_POPULATE_CACHE") != "1":
+ pytest.skip("only runs in CI to populate cache")
+
+ if key in KNOWN_INVALID:
+ pytest.skip("known failure")
+
+ from bioimageio.core import load_description
+
+ _ = load_description(descr_url, sha256=sha, perform_io_checks=True)
+
+
+@expensive_test
+@pytest.mark.parametrize("descr_url,sha,key", list(yield_bioimageio_yaml_urls()))
+def test_rdf(
+ descr_url: HttpUrl,
+ sha: Sha256,
+ key: str,
+):
+ if key in KNOWN_INVALID:
+ pytest.skip("known failure")
+
+ from bioimageio.core import load_description_and_test
+
+ descr = load_description_and_test(descr_url, sha256=sha, stop_early=True)
+
+ assert not isinstance(descr, InvalidDescr), descr.validation_summary.display()
+ assert (
+ descr.validation_summary.status == "passed"
+ ), descr.validation_summary.display()
diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py
index 921ecd9c..2418baa5 100644
--- a/tests/test_bioimageio_spec_version.py
+++ b/tests/test_bioimageio_spec_version.py
@@ -8,7 +8,7 @@
def test_bioimageio_spec_version(conda_cmd: Optional[str]):
if conda_cmd is None:
- pytest.skip("requires mamba")
+ pytest.skip("requires conda")
from importlib.metadata import metadata
diff --git a/tests/test_cli.py b/tests/test_cli.py
index e0828ac6..203677ec 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,4 +1,5 @@
import subprocess
+from pathlib import Path
from typing import Any, List, Sequence
import pytest
@@ -36,11 +37,30 @@ def run_subprocess(
],
["test", "unet2d_nuclei_broad_model"],
["predict", "--example", "unet2d_nuclei_broad_model"],
+ ["update-format", "unet2d_nuclei_broad_model_old"],
+ ["add-weights", "unet2d_nuclei_broad_model", "tmp_path"],
+ ["update-hashes", "unet2d_nuclei_broad_model_old"],
+ ["update-hashes", "unet2d_nuclei_broad_model_old", "--output=stdout"],
],
)
-def test_cli(args: List[str], unet2d_nuclei_broad_model: str):
+def test_cli(
+ args: List[str],
+ unet2d_nuclei_broad_model: str,
+ unet2d_nuclei_broad_model_old: str,
+ tmp_path: Path,
+):
resolved_args = [
- str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg
+ (
+ unet2d_nuclei_broad_model
+ if arg == "unet2d_nuclei_broad_model"
+ else (
+ unet2d_nuclei_broad_model_old
+ if arg == "unet2d_nuclei_broad_model_old"
+ else (
+ arg.replace("tmp_path", str(tmp_path)) if "tmp_path" in arg else arg
+ )
+ )
+ )
for arg in args
]
ret = run_subprocess(["bioimageio", *resolved_args])
diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py
index a0a85f5d..08e9f094 100644
--- a/tests/test_prediction_pipeline.py
+++ b/tests/test_prediction_pipeline.py
@@ -2,12 +2,15 @@
from numpy.testing import assert_array_almost_equal
+from bioimageio.core.common import SupportedWeightsFormat
from bioimageio.spec import load_description
from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04
-from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
+from bioimageio.spec.model.v0_5 import ModelDescr
-def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat):
+def _test_prediction_pipeline(
+ model_package: Path, weights_format: SupportedWeightsFormat
+):
from bioimageio.core._prediction_pipeline import create_prediction_pipeline
from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs
diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py
index 0e241df1..0d2ff9b7 100644
--- a/tests/test_prediction_pipeline_device_management.py
+++ b/tests/test_prediction_pipeline_device_management.py
@@ -1,17 +1,14 @@
from pathlib import Path
+import pytest
from numpy.testing import assert_array_almost_equal
-from bioimageio.core.utils.testing import skip_on
+from bioimageio.core.common import SupportedWeightsFormat
from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04
-from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
+from bioimageio.spec.model.v0_5 import ModelDescr
-class TooFewDevicesException(Exception):
- pass
-
-
-def _test_device_management(model_package: Path, weight_format: WeightsFormat):
+def _test_device_management(model_package: Path, weight_format: SupportedWeightsFormat):
import torch
from bioimageio.core import load_description
@@ -19,7 +16,7 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat):
from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs
if not hasattr(torch, "cuda") or torch.cuda.device_count() == 0:
- raise TooFewDevicesException("Need at least one cuda device for this test")
+ pytest.skip("Need at least one cuda device for this test")
bio_model = load_description(model_package)
assert isinstance(bio_model, (ModelDescr, ModelDescr04))
@@ -52,26 +49,21 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat):
assert_array_almost_equal(out, exp, decimal=4)
-@skip_on(TooFewDevicesException, reason="Too few devices")
def test_device_management_torch(any_torch_model: Path):
_test_device_management(any_torch_model, "pytorch_state_dict")
-@skip_on(TooFewDevicesException, reason="Too few devices")
def test_device_management_torchscript(any_torchscript_model: Path):
_test_device_management(any_torchscript_model, "torchscript")
-@skip_on(TooFewDevicesException, reason="Too few devices")
def test_device_management_onnx(any_onnx_model: Path):
_test_device_management(any_onnx_model, "onnx")
-@skip_on(TooFewDevicesException, reason="Too few devices")
def test_device_management_tensorflow(any_tensorflow_model: Path):
_test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle")
-@skip_on(TooFewDevicesException, reason="Too few devices")
def test_device_management_keras(any_keras_model: Path):
_test_device_management(any_keras_model, "keras_hdf5")
diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py
index e408d220..be87f54b 100644
--- a/tests/test_proc_ops.py
+++ b/tests/test_proc_ops.py
@@ -21,15 +21,17 @@ def tid():
def test_scale_linear(tid: MemberId):
from bioimageio.core.proc_ops import ScaleLinear
- offset = xr.DataArray([1, 2, 42], dims=("c"))
- gain = xr.DataArray([1, 2, 3], dims=("c"))
- data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c"))
+ offset = xr.DataArray([1, 2, 42], dims=("channel",))
+ gain = xr.DataArray([1, 2, 3], dims=("channel",))
+ data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "channel"))
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain)
op(sample)
- expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c"))
+ expected = xr.DataArray(
+ np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "channel")
+ )
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7)
@@ -84,10 +86,10 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId):
op = FixedZeroMeanUnitVariance(
tid,
tid,
- mean=xr.DataArray([3, 4, 5], dims=("c")),
- std=xr.DataArray([2.44948974, 2.44948974, 2.44948974], dims=("c")),
+ mean=xr.DataArray([3, 4, 5], dims=("channel",)),
+ std=xr.DataArray([2.44948974, 2.44948974, 2.44948974], dims=("channel",)),
)
- data = xr.DataArray(np.arange(9).reshape((1, 3, 3)), dims=("b", "c", "x"))
+ data = xr.DataArray(np.arange(9).reshape((1, 3, 3)), dims=("batch", "channel", "x"))
expected = xr.DataArray(
np.array(
[
@@ -98,17 +100,33 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId):
]
]
),
- dims=("b", "c", "x"),
+ dims=("batch", "channel", "x"),
)
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
op(sample)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7)
+def test_zero_mean_unit_variance_fixed2(tid: MemberId):
+ from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance
+
+ np_data = np.arange(9).reshape(3, 3)
+ mean = float(np_data.mean())
+ std = float(np_data.mean())
+ eps = 1.0e-7
+ op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps)
+
+ data = xr.DataArray(np_data, dims=("x", "y"))
+ sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
+ expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y"))
+ op(sample)
+ xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7)
+
+
def test_zero_mean_unit_across_axes(tid: MemberId):
from bioimageio.core.proc_ops import ZeroMeanUnitVariance
- data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y"))
+ data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("channel", "x", "y"))
op = ZeroMeanUnitVariance(
tid,
@@ -120,33 +138,18 @@ def test_zero_mean_unit_across_axes(tid: MemberId):
sample.stat = compute_measures(op.required_measures, [sample])
expected = xr.concat(
- [(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], dim="c"
+ [(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)],
+ dim="channel",
)
op(sample)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7)
-def test_zero_mean_unit_variance_fixed2(tid: MemberId):
- from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance
-
- np_data = np.arange(9).reshape(3, 3)
- mean = float(np_data.mean())
- std = float(np_data.mean())
- eps = 1.0e-7
- op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps)
-
- data = xr.DataArray(np_data, dims=("x", "y"))
- sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
- expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y"))
- op(sample)
- xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7)
-
-
def test_binarize(tid: MemberId):
from bioimageio.core.proc_ops import Binarize
op = Binarize(tid, tid, threshold=14)
- data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c"))
+ data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "channel"))
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
expected = xr.zeros_like(data)
expected[{"x": slice(1, None)}] = 1
@@ -158,7 +161,7 @@ def test_binarize2(tid: MemberId):
from bioimageio.core.proc_ops import Binarize
shape = (3, 32, 32)
- axes = ("c", "y", "x")
+ axes = ("channel", "y", "x")
np_data = np.random.rand(*shape)
data = xr.DataArray(np_data, dims=axes)
@@ -188,7 +191,7 @@ def test_clip(tid: MemberId):
def test_combination_of_op_steps_with_dims_specified(tid: MemberId):
from bioimageio.core.proc_ops import ZeroMeanUnitVariance
- data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y"))
+ data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("channel", "x", "y"))
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
op = ZeroMeanUnitVariance(
tid,
@@ -219,7 +222,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: MemberId):
],
]
),
- dims=("c", "x", "y"),
+ dims=("channel", "x", "y"),
)
op(sample)
@@ -239,7 +242,7 @@ def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]):
from bioimageio.core.proc_ops import ScaleMeanVariance
shape = (3, 32, 46)
- ipt_axes = ("c", "y", "x")
+ ipt_axes = ("channel", "y", "x")
np_data = np.random.rand(*shape)
ipt_data = xr.DataArray(np_data, dims=ipt_axes)
ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes)
@@ -268,7 +271,7 @@ def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str])
axes = None if axes_str is None else tuple(map(AxisId, axes_str))
shape = (3, 32, 46)
- ipt_axes = ("c", "y", "x")
+ ipt_axes = ("channel", "y", "x")
np_data = np.random.rand(*shape)
ipt_data = xr.DataArray(np_data, dims=ipt_axes)
@@ -334,7 +337,7 @@ def test_scale_range_axes(tid: MemberId):
op = ScaleRange(tid, tid, lower_quantile, upper_quantile, eps=eps)
np_data = np.arange(18).reshape((2, 3, 3)).astype("float32")
- data = Tensor.from_xarray(xr.DataArray(np_data, dims=("c", "x", "y")))
+ data = Tensor.from_xarray(xr.DataArray(np_data, dims=("channel", "x", "y")))
sample = Sample(members={tid: data}, stat={}, id=None)
p_low_direct = lower_quantile.compute(sample)
@@ -352,7 +355,7 @@ def test_scale_range_axes(tid: MemberId):
np.testing.assert_allclose(p_up_expected.squeeze(), sample.stat[upper_quantile])
exp_data = (np_data - p_low_expected) / (p_up_expected - p_low_expected + eps)
- expected = xr.DataArray(exp_data, dims=("c", "x", "y"))
+ expected = xr.DataArray(exp_data, dims=("channel", "x", "y"))
op(sample)
# NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct
@@ -363,7 +366,7 @@ def test_sigmoid(tid: MemberId):
from bioimageio.core.proc_ops import Sigmoid
shape = (3, 32, 32)
- axes = ("c", "y", "x")
+ axes = ("channel", "y", "x")
np_data = np.random.rand(*shape)
data = xr.DataArray(np_data, dims=axes)
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py
index 203ca64b..f4eca96b 100644
--- a/tests/test_resource_tests.py
+++ b/tests/test_resource_tests.py
@@ -1,15 +1,4 @@
-from typing import Literal
-
-import pytest
-
-from bioimageio.spec import InvalidDescr
-
-
-@pytest.mark.parametrize("mode", ["seed_only", "full"])
-def test_enable_determinism(mode: Literal["seed_only", "full"]):
- from bioimageio.core import enable_determinism
-
- enable_determinism(mode)
+from bioimageio.spec import InvalidDescr, ValidationContext
def test_error_for_wrong_shape(stardist_wrong_shape: str):
@@ -38,15 +27,10 @@ def test_error_for_wrong_shape2(stardist_wrong_shape2: str):
def test_test_model(any_model: str):
from bioimageio.core._resource_tests import test_model
- summary = test_model(any_model)
- assert summary.status == "passed", summary.format()
-
-
-def test_test_resource(any_model: str):
- from bioimageio.core._resource_tests import test_description
+ with ValidationContext(raise_errors=True):
+ summary = test_model(any_model)
- summary = test_description(any_model)
- assert summary.status == "passed", summary.format()
+ assert summary.status == "passed", summary.display()
def test_loading_description_multiple_times(unet2d_nuclei_broad_model: str):
diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py
index 57e86c5a..dd2823b6 100644
--- a/tests/test_stat_calculators.py
+++ b/tests/test_stat_calculators.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple
import numpy as np
import pytest
@@ -12,6 +12,9 @@
DatasetMean,
DatasetStd,
DatasetVar,
+ SampleMean,
+ SampleStd,
+ SampleVar,
)
from bioimageio.core.tensor import Tensor
@@ -27,18 +30,53 @@ def create_random_dataset(tid: MemberId, axes: Tuple[AxisId, ...]):
return Tensor(data, dims=axes), ds
+@pytest.mark.parametrize(
+ "axes",
+ [
+ (AxisId("x"), AxisId("y")),
+ (AxisId("channel"), AxisId("y")),
+ ],
+)
+def test_sample_mean_var_std_calculator(axes: Optional[Tuple[AxisId, ...]]):
+ tid = MemberId("tensor")
+ d_axes = tuple(map(AxisId, ("batch", "channel", "x", "y")))
+ data, ds = create_random_dataset(tid, d_axes)
+ expected_mean = data[0:1].mean(axes)
+ expected_var = data[0:1].var(axes)
+ expected_std = data[0:1].std(axes)
+
+ calc = MeanVarStdCalculator(tid, axes=axes)
+
+ actual = calc.compute(ds[0])
+ actual_mean = actual[SampleMean(member_id=tid, axes=axes)]
+ actual_var = actual[SampleVar(member_id=tid, axes=axes)]
+ actual_std = actual[SampleStd(member_id=tid, axes=axes)]
+
+ assert_allclose(
+ actual_mean if isinstance(actual_mean, (int, float)) else actual_mean.data,
+ expected_mean.data,
+ )
+ assert_allclose(
+ actual_var if isinstance(actual_var, (int, float)) else actual_var.data,
+ expected_var.data,
+ )
+ assert_allclose(
+ actual_std if isinstance(actual_std, (int, float)) else actual_std.data,
+ expected_std.data,
+ )
+
+
@pytest.mark.parametrize(
"axes",
[
None,
- ("x", "y"),
- ("channel", "y"),
+ (AxisId("batch"), AxisId("channel"), AxisId("x"), AxisId("y")),
],
)
-def test_mean_var_std_calculator(axes: Union[None, str, Tuple[str, ...]]):
+def test_dataset_mean_var_std_calculator(axes: Optional[Tuple[AxisId, ...]]):
tid = MemberId("tensor")
- axes = tuple(map(AxisId, ("batch", "channel", "x", "y")))
- data, ds = create_random_dataset(tid, axes)
+ d_axes = tuple(map(AxisId, ("batch", "channel", "x", "y")))
+ data, ds = create_random_dataset(tid, d_axes)
expected_mean = data.mean(axes)
expected_var = data.var(axes)
expected_std = data.std(axes)
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index 33163077..c57980bd 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -1,3 +1,5 @@
+from typing import Sequence
+
import numpy as np
import pytest
import xarray as xr
@@ -8,9 +10,19 @@
@pytest.mark.parametrize(
"axes",
- ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"],
+ [
+ "yx",
+ "xy",
+ "cyx",
+ "yxc",
+ "bczyx",
+ "xyz",
+ "xyzc",
+ "bzyxc",
+ ("batch", "channel", "x", "y"),
+ ],
)
-def test_transpose_tensor_2d(axes: str):
+def test_transpose_tensor_2d(axes: Sequence[str]):
tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None)
transposed = tensor.transpose([AxisId(a) for a in axes])
@@ -19,9 +31,18 @@ def test_transpose_tensor_2d(axes: str):
@pytest.mark.parametrize(
"axes",
- ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"],
+ [
+ "zyx",
+ "cyzx",
+ "yzixc",
+ "bczyx",
+ "xyz",
+ "xyzc",
+ "bzyxtc",
+ ("batch", "channel", "x", "y", "z"),
+ ],
)
-def test_transpose_tensor_3d(axes: str):
+def test_transpose_tensor_3d(axes: Sequence[str]):
tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None)
transposed = tensor.transpose([AxisId(a) for a in axes])
assert transposed.ndim == len(axes)
@@ -39,3 +60,8 @@ def test_crop_and_pad():
def test_some_magic_ops():
tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None)
assert tensor + 2 == 2 + tensor
+
+
+def test_shape_attributes():
+ tensor = Tensor.from_numpy(np.random.rand(1, 2, 25, 26), dims=None)
+ assert tensor.shape_tuple == tensor.shape
diff --git a/tests/test_weight_converters.py b/tests/test_weight_converters.py
new file mode 100644
index 00000000..1bf65782
--- /dev/null
+++ b/tests/test_weight_converters.py
@@ -0,0 +1,108 @@
+# type: ignore # TODO enable type checking
+import os
+import zipfile
+from pathlib import Path
+
+import pytest
+
+from bioimageio.spec import load_description
+from bioimageio.spec.model import v0_5
+
+
+def test_pytorch_to_torchscript(any_torch_model, tmp_path):
+ from bioimageio.core import test_model
+ from bioimageio.core.weight_converters.pytorch_to_torchscript import convert
+
+ model_descr = load_description(any_torch_model, perform_io_checks=False)
+ if model_descr.implemented_format_version_tuple[:2] == (0, 4):
+ pytest.skip("cannot convert to old 0.4 format")
+
+ out_path = tmp_path / "weights.pt"
+ ret_val = convert(model_descr, out_path)
+ assert out_path.exists()
+ assert isinstance(ret_val, v0_5.TorchscriptWeightsDescr)
+ assert ret_val.source == out_path
+ model_descr.weights.torchscript = ret_val
+ summary = test_model(model_descr, weight_format="torchscript")
+ assert summary.status == "passed", summary.display()
+
+
+def test_pytorch_to_onnx(convert_to_onnx, tmp_path):
+ from bioimageio.core import test_model
+ from bioimageio.core.weight_converters.pytorch_to_onnx import convert
+
+ model_descr = load_description(convert_to_onnx, format_version="latest")
+ out_path = tmp_path / "weights.onnx"
+ opset_version = 15
+ ret_val = convert(
+ model_descr=model_descr,
+ output_path=out_path,
+ opset_version=opset_version,
+ )
+ assert os.path.exists(out_path)
+ assert isinstance(ret_val, v0_5.OnnxWeightsDescr)
+ assert ret_val.opset_version == opset_version
+ assert ret_val.source == out_path
+
+ model_descr.weights.onnx = ret_val
+ summary = test_model(model_descr, weight_format="onnx")
+ assert summary.status == "passed", summary.display()
+
+
+@pytest.mark.skip()
+def test_keras_to_tensorflow(any_keras_model: Path, tmp_path: Path):
+ from bioimageio.core import test_model
+ from bioimageio.core.weight_converters.keras_to_tensorflow import convert
+
+ out_path = tmp_path / "weights.zip"
+ model_descr = load_description(any_keras_model)
+ ret_val = convert(model_descr, out_path)
+
+ assert out_path.exists()
+ assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)
+
+ expected_names = {"saved_model.pb", "variables/variables.index"}
+ with zipfile.ZipFile(out_path, "r") as f:
+ names = set([name for name in f.namelist()])
+ assert len(expected_names - names) == 0
+
+ model_descr.weights.keras = ret_val
+ summary = test_model(model_descr, weight_format="keras_hdf5")
+ assert summary.status == "passed", summary.display()
+
+
+# TODO: add tensorflow_to_keras converter
+# def test_tensorflow_to_keras(any_tensorflow_model: Path, tmp_path: Path):
+# from bioimageio.core.weight_converters.tensorflow_to_keras import convert
+
+# model_descr = load_description(any_tensorflow_model)
+# out_path = tmp_path / "weights.h5"
+# ret_val = convert(model_descr, output_path=out_path)
+# assert out_path.exists()
+# assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)
+# assert ret_val.source == out_path
+
+# model_descr.weights.keras = ret_val
+# summary = test_model(model_descr, weight_format="keras_hdf5")
+# assert summary.status == "passed", summary.display()
+
+
+# @pytest.mark.skip()
+# def test_tensorflow_to_keras_zipped(any_tensorflow_model: Path, tmp_path: Path):
+# from bioimageio.core.weight_converters.tensorflow_to_keras import convert
+
+# out_path = tmp_path / "weights.zip"
+# model_descr = load_description(any_tensorflow_model)
+# ret_val = convert(model_descr, out_path)
+
+# assert out_path.exists()
+# assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)
+
+# expected_names = {"saved_model.pb", "variables/variables.index"}
+# with zipfile.ZipFile(out_path, "r") as f:
+# names = set([name for name in f.namelist()])
+# assert len(expected_names - names) == 0
+
+# model_descr.weights.keras = ret_val
+# summary = test_model(model_descr, weight_format="keras_hdf5")
+# assert summary.status == "passed", summary.display()
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..f9116fa5
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,21 @@
+"""utils to test bioimageio.core"""
+
+import os
+from typing import Any, Protocol, Sequence
+
+import pytest
+
+
+class ParameterSet(Protocol):
+ def __init__(self, values: Sequence[Any], marks: Any, id: str) -> None:
+ super().__init__()
+
+
+class test_func(Protocol):
+ def __call__(*args: Any, **kwargs: Any): ...
+
+
+expensive_test = pytest.mark.skipif(
+ os.getenv("RUN_EXPENSIVE_TESTS") != "true",
+ reason="Skipping expensive test (enable by RUN_EXPENSIVE_TESTS='true')",
+)
diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py
deleted file mode 100644
index 65c93f60..00000000
--- a/tests/weight_converter/keras/test_tensorflow.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# type: ignore # TODO enable type checking
-import zipfile
-from pathlib import Path
-
-import pytest
-
-from bioimageio.spec import load_description
-from bioimageio.spec.model.v0_5 import ModelDescr
-
-
-@pytest.mark.skip(
- "tensorflow converter not updated yet"
-) # TODO: test tensorflow converter
-def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path):
- from bioimageio.core.weight_converter.keras import (
- convert_weights_to_tensorflow_saved_model_bundle,
- )
-
- out_path = tmp_path / "weights"
- model = load_description(any_keras_model)
- assert isinstance(model, ModelDescr), model.validation_summary.format()
- ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path)
- assert out_path.exists()
- assert (out_path / "variables").exists()
- assert (out_path / "saved_model.pb").exists()
- assert (
- ret_val == 0
- ) # check for correctness is done in converter and returns 0 if it passes
-
-
-@pytest.mark.skip(
- "tensorflow converter not updated yet"
-) # TODO: test tensorflow converter
-def test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path):
- from bioimageio.core.weight_converter.keras import (
- convert_weights_to_tensorflow_saved_model_bundle,
- )
-
- out_path = tmp_path / "weights.zip"
- model = load_description(any_keras_model)
- assert isinstance(model, ModelDescr), model.validation_summary.format()
- ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path)
- assert out_path.exists()
- assert (
- ret_val == 0
- ) # check for correctness is done in converter and returns 0 if it passes
-
- # make sure that the zip package was created correctly
- expected_names = {"saved_model.pb", "variables/variables.index"}
- with zipfile.ZipFile(out_path, "r") as f:
- names = set([name for name in f.namelist()])
- assert len(expected_names - names) == 0
diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py
deleted file mode 100644
index 54f2cdf4..00000000
--- a/tests/weight_converter/torch/test_onnx.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# type: ignore # TODO enable type checking
-import os
-from pathlib import Path
-
-import pytest
-
-
-@pytest.mark.skip("onnx converter not updated yet") # TODO: test onnx converter
-def test_onnx_converter(convert_to_onnx: Path, tmp_path: Path):
- from bioimageio.core.weight_converter.torch._onnx import convert_weights_to_onnx
-
- out_path = tmp_path / "weights.onnx"
- ret_val = convert_weights_to_onnx(convert_to_onnx, out_path, test_decimal=3)
- assert os.path.exists(out_path)
- if not pytest.skip_onnx:
- assert (
- ret_val == 0
- ) # check for correctness is done in converter and returns 0 if it passes
diff --git a/tests/weight_converter/torch/test_torchscript.py b/tests/weight_converter/torch/test_torchscript.py
deleted file mode 100644
index e0cee3d8..00000000
--- a/tests/weight_converter/torch/test_torchscript.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# type: ignore # TODO enable type checking
-from pathlib import Path
-
-import pytest
-
-from bioimageio.spec.model import v0_4, v0_5
-
-
-@pytest.mark.skip(
- "torchscript converter not updated yet"
-) # TODO: test torchscript converter
-def test_torchscript_converter(
- any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path
-):
- from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript
-
- out_path = tmp_path / "weights.pt"
- ret_val = convert_weights_to_torchscript(any_torch_model, out_path)
- assert out_path.exists()
- assert (
- ret_val == 0
- ) # check for correctness is done in converter and returns 0 if it passes