diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index c890e8df..8ce50a1a 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -8,7 +8,7 @@ on: defaults: run: - shell: micromamba-shell {0} + shell: bash -el {0} jobs: black: @@ -22,215 +22,218 @@ jobs: jupyter: true version: "24.3" - test-spec-conda: + populate-cache: runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + outputs: + cache-key: ${{steps.cache-key.outputs.cache-key}} steps: - - uses: actions/checkout@v4 - - name: Install Conda environment with Micromamba - if: matrix.python-version != '3.8' - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-file: dev/env-wo-python.yaml - create-args: >- - python=${{ matrix.python-version }} - post-cleanup: 'all' - - name: Install py3.8 environment - if: matrix.python-version == '3.8' - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-file: dev/env-py38.yaml - post-cleanup: 'all' - - name: additional setup - run: pip install --no-deps -e . - - name: Get Date - id: get-date - run: | - echo "date=$(date +'%Y-%b')" - echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT - shell: bash - - uses: actions/cache@v4 - with: - path: bioimageio_cache - key: "test-spec-conda-${{ steps.get-date.outputs.date }}" - - name: pytest-spec-conda - run: pytest --disable-pytest-warnings - env: - BIOIMAGEIO_CACHE_PATH: bioimageio_cache - - test-spec-main: + - name: Get Date + id: get-date + run: | + echo "date=$(date +'%Y-%b')" + echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT + - id: cache-key + run: echo "cache-key=test-${{steps.get-date.outputs.date}}" >> $GITHUB_OUTPUT + - uses: actions/cache/restore@v4 + id: look-up + with: + path: bioimageio_cache + key: ${{steps.cache-key.outputs.cache-key}} + lookup-only: true + - uses: actions/checkout@v4 + if: steps.look-up.outputs.cache-hit != 'true' + - uses: actions/cache@v4 + if: steps.look-up.outputs.cache-hit != 'true' + with: + path: bioimageio_cache + key: ${{steps.cache-key.outputs.cache-key}} + - uses: actions/setup-python@v5 + if: steps.look-up.outputs.cache-hit != 'true' + with: + python-version: '3.12' + cache: 'pip' + - name: Install dependencies + if: steps.look-up.outputs.cache-hit != 'true' + run: | + pip install --upgrade pip + pip install -e .[dev] + - run: pytest --disable-pytest-warnings tests/test_bioimageio_collection.py::test_rdf_format_to_populate_cache + if: steps.look-up.outputs.cache-hit != 'true' + env: + BIOIMAGEIO_POPULATE_CACHE: '1' + BIOIMAGEIO_CACHE_PATH: bioimageio_cache + test: + needs: populate-cache runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.12'] include: + - python-version: '3.8' + conda-env: py38 + spec: conda + - python-version: '3.8' + conda-env: py38 + spec: main + - python-version: '3.9' + conda-env: dev + spec: conda + - python-version: '3.10' + conda-env: dev + spec: conda + - python-version: '3.11' + conda-env: full + spec: main + run-expensive-tests: true + report-coverage: true + save-cache: true - python-version: '3.12' - is-dev-version: true + conda-env: dev + spec: conda + - python-version: '3.13' + conda-env: dev + spec: main + save-cache: true + steps: - uses: actions/checkout@v4 - - name: Install Conda environment with Micromamba - if: matrix.python-version != '3.8' - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-file: 
dev/env-wo-python.yaml - create-args: >- - python=${{ matrix.python-version }} - post-cleanup: 'all' - - name: Install py3.8 environment - if: matrix.python-version == '3.8' - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-file: dev/env-py38.yaml - post-cleanup: 'all' - - name: additional setup spec + - id: setup + run: | + echo "env-name=${{ matrix.spec }}-${{ matrix.conda-env }}-${{ matrix.python-version }}" + echo "env-name=${{ matrix.spec }}-${{ matrix.conda-env }}-${{ matrix.python-version }}" >> $GITHUB_OUTPUT + echo "env-file=dev/env-${{ matrix.conda-env }}.yaml" + echo "env-file=dev/env-${{ matrix.conda-env }}.yaml" >> $GITHUB_OUTPUT + - name: check on env-file + shell: python run: | - conda remove --yes --force bioimageio.spec || true # allow failure for cached env - pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - - name: additional setup core - run: pip install --no-deps -e . + from pathlib import Path + from pprint import pprint + if not (env_path:=Path("${{steps.setup.outputs.env-file}}")).exists(): + if env_path.parent.exists(): + pprint(sorted(env_path.parent.glob("*")))  # sorted() materializes the generator so the paths are actually printed + else: + pprint(sorted(Path().glob("*"))) + raise FileNotFoundError(f"{env_path} does not exist") + + - uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + auto-activate-base: true + activate-environment: ${{steps.setup.outputs.env-name}} + channel-priority: strict + miniforge-version: latest - name: Get Date id: get-date run: | - echo "date=$(date +'%Y-%b')" - echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT - shell: bash - - uses: actions/cache@v4 + echo "today=$(date -u '+%Y%m%d')" + echo "today=$(date -u '+%Y%m%d')" >> $GITHUB_OUTPUT + - name: Restore cached env + uses: actions/cache/restore@v4 + with: + path: ${{env.CONDA}}/envs/${{steps.setup.outputs.env-name}} + key: >- + conda-${{runner.os}}-${{runner.arch}} + -${{steps.get-date.outputs.today}} + -${{hashFiles(steps.setup.outputs.env-file)}} + -${{env.CACHE_NUMBER}} + env: + CACHE_NUMBER: 0 + id: cache-env + - name: Install env + run: conda env update --name=${{steps.setup.outputs.env-name}} --file=${{steps.setup.outputs.env-file}} python=${{matrix.python-version}} + if: steps.cache-env.outputs.cache-hit != 'true' + - name: Install uncached pip dependencies + run: | + pip install --upgrade pip + pip install --no-deps -e .
+ - name: Install uncached pip dependencies for 'full' environment + if: matrix.conda-env == 'full' + run: | + pip install git+https://github.com/ChaoningZhang/MobileSAM.git + - name: Cache env + if: steps.cache-env.outputs.cache-hit != 'true' + uses: actions/cache/save@v4 + with: + path: ${{env.CONDA}}/envs/${{steps.setup.outputs.env-name}} + key: >- + conda-${{runner.os}}-${{runner.arch}} + -${{steps.get-date.outputs.today}} + -${{hashFiles(steps.setup.outputs.env-file)}} + -${{env.CACHE_NUMBER}} + env: + CACHE_NUMBER: 0 + - run: conda list + - name: Pyright + run: | + pyright --version + pyright -p pyproject.toml --pythonversion ${{ matrix.python-version }} + if: matrix.run-expensive-tests + - name: Restore bioimageio cache + uses: actions/cache/restore@v4 + id: bioimageio-cache with: path: bioimageio_cache - key: "test-spec-main-${{ steps.get-date.outputs.date }}" - - name: pytest-spec-main + key: ${{needs.populate-cache.outputs.cache-key}}${{matrix.run-expensive-tests && '' || '-light'}} + - name: pytest run: pytest --disable-pytest-warnings env: BIOIMAGEIO_CACHE_PATH: bioimageio_cache - - if: matrix.is-dev-version && github.event_name == 'pull_request' + RUN_EXPENSIVE_TESTS: ${{ matrix.run-expensive-tests && 'true' || 'false' }} + - name: Save updated bioimageio cache + if: matrix.save-cache + uses: actions/cache/save@v4 + with: + path: bioimageio_cache + key: ${{needs.populate-cache.outputs.cache-key}}${{matrix.run-expensive-tests && '' || '-light'}} + - if: matrix.report-coverage && github.event_name == 'pull_request' uses: orgoro/coverage@v3.2 with: coverageFile: coverage.xml - token: ${{ secrets.GITHUB_TOKEN }} - - if: matrix.is-dev-version && github.ref == 'refs/heads/main' + token: ${{secrets.GITHUB_TOKEN}} + - if: matrix.report-coverage && github.ref == 'refs/heads/main' run: | pip install genbadge[coverage] genbadge coverage --input-file coverage.xml --output-file ./dist/coverage/coverage-badge.svg coverage html -d dist/coverage - - if: matrix.is-dev-version && github.ref == 'refs/heads/main' + - if: matrix.report-coverage && github.ref == 'refs/heads/main' uses: actions/upload-artifact@v4 with: name: coverage retention-days: 1 path: dist - - test-spec-main-tf: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['3.9', '3.12'] - steps: - - uses: actions/checkout@v4 - - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-file: dev/env-tf.yaml - condarc: | - channel-priority: flexible - create-args: >- - python=${{ matrix.python-version }} - post-cleanup: 'all' - - name: additional setup spec - run: | - conda remove --yes --force bioimageio.spec || true # allow failure for cached env - pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - - name: additional setup core - run: pip install --no-deps -e . 
- - name: Get Date - id: get-date - run: | - echo "date=$(date +'%Y-%b')" - echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT - shell: bash - - uses: actions/cache@v4 - with: - path: bioimageio_cache - key: "test-spec-main-tf-${{ steps.get-date.outputs.date }}" - - run: pytest --disable-pytest-warnings - env: - BIOIMAGEIO_CACHE_PATH: bioimageio_cache - - test-spec-conda-tf: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['3.9', '3.12'] - steps: - - uses: actions/checkout@v4 - - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-file: dev/env-tf.yaml - condarc: | - channel-priority: flexible - create-args: >- - python=${{ matrix.python-version }} - post-cleanup: 'all' - - name: additional setup - run: pip install --no-deps -e . - - name: Get Date - id: get-date - run: | - echo "date=$(date +'%Y-%b')" - echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT - shell: bash - - uses: actions/cache@v4 - with: - path: bioimageio_cache - key: "test-spec-conda-tf-${{ steps.get-date.outputs.date }}" - - name: pytest-spec-tf - run: pytest --disable-pytest-warnings - env: - BIOIMAGEIO_CACHE_PATH: bioimageio_cache - conda-build: + needs: test runs-on: ubuntu-latest - needs: test-spec-conda steps: - - name: checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Install Conda environment with Micromamba - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-name: build-env - condarc: | - channels: - - conda-forge - create-args: | - boa - - name: linux conda build - run: | - conda mambabuild -c conda-forge conda-recipe + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + auto-activate-base: true + activate-environment: "" + channel-priority: strict + miniforge-version: latest + conda-solver: libmamba + - name: install common conda dependencies + run: conda install -n base -c conda-forge conda-build -y + - uses: actions/cache@v4 + with: + path: | + pkgs/noarch + pkgs/channeldata.json + key: ${{ github.sha }}-packages + - name: linux conda build test + shell: bash -l {0} + run: | + mkdir -p ./pkgs/noarch + conda-build -c conda-forge conda-recipe --no-test --output-folder ./pkgs docs: - needs: [test-spec-main] + needs: test if: github.ref == 'refs/heads/main' runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 diff --git a/.gitignore b/.gitignore index d8be60be..688e4a88 100644 --- a/.gitignore +++ b/.gitignore @@ -6,9 +6,12 @@ __pycache__/ *.egg-info/ *.pyc **/tmp +bioimageio_unzipped_tf_weights/ build/ cache coverage.xml dist/ docs/ +dogfood/ typings/pooch/ +bioimageio_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ef0eba58..2bee435e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: rev: v0.3.2 hooks: - id: ruff - args: [--fix] + args: [--fix, --show-fixes] - repo: local hooks: - id: pyright diff --git a/README.md b/README.md index 92bcb9b0..1a1aa932 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,7 @@ functionality, but not any functionality depending on model prediction. To install additional deep learning libraries add `pytorch`, `onnxruntime`, `keras` or `tensorflow`. 
Deep learning frameworks to consider installing alongside `bioimageio.core`: + - [Pytorch/Torchscript](https://pytorch.org/get-started/locally/) - [TensorFlow](https://www.tensorflow.org/install) - [ONNXRuntime](https://onnxruntime.ai/docs/install/#python-installs) @@ -297,13 +298,16 @@ These models are described by---and can be loaded with---the bioimageio.spec pac In addition, bioimageio.core provides functionality to convert model weight formats. ### Documentation + [Here you find the bioimageio.core documentation.](https://bioimage-io.github.io/core-bioimage-io-python/bioimageio/core.html) #### Presentations + - [Create a model from scratch](https://bioimage-io.github.io/core-bioimage-io-python/presentations/create_ambitious_sloth.slides.html) ([source](https://github.com/bioimage-io/core-bioimage-io-python/tree/main/presentations)) #### Examples -<dl> + +<dl> <dt>Notebooks that save and load resource descriptions and validate their format (using <a href="https://bioimage-io.github.io/core-bioimage-io-python/bioimageio/spec.html">bioimageio.spec</a>, a dependency of bioimageio.core)</dt> <dd><a href="https://github.com/bioimage-io/spec-bioimage-io/blob/main/example/load_model_and_create_your_own.ipynb">load_model_and_create_your_own.ipynb</a> <a target="_blank" href="https://colab.research.google.com/github/bioimage-io/spec-bioimage-io/blob/main/example/load_model_and_create_your_own.ipynb"> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> @@ -327,7 +331,6 @@ bioimageio For examples see [Get started](#get-started). - ### CLI inputs from file For convenience the command line options (not arguments) may be given in a `bioimageio-cli.json` @@ -342,7 +345,6 @@ blockwise: true stats: inputs/dataset_statistics.json ``` - ## Set up Development Environment To set up a development conda environment run the following commands: @@ -355,15 +357,21 @@ pip install -e . --no-deps There are different environment files available that only install tensorflow or pytorch as dependencies, see [dev folder](https://github.com/bioimage-io/core-bioimage-io-python/tree/main/dev). - ## Logging level `bioimageio.spec` and `bioimageio.core` use [loguru](https://github.com/Delgan/loguru) for logging, hence the logging level may be controlled with the `LOGURU_LEVEL` environment variable.
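+
+For example, to have `bioimageio.core` emit debug messages (a minimal sketch; note that `LOGURU_LEVEL` must be set before `loguru` is imported, i.e. before the first `bioimageio` import):
+
+```python
+import os
+
+os.environ["LOGURU_LEVEL"] = "DEBUG"  # set before importing bioimageio.core
+
+from bioimageio.core import test_description  # logs at DEBUG level from here on
+```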
- ## Changelog +### 0.7.1 (to be released) + +- breaking: removed `decimals` argument from bioimageio CLI and `bioimageio.core.commands.test()` +- New feature: `bioimageio.core.test_description` accepts **runtime_env** and **run_command** to test a resource + using the conda environment described by that resource (or another specified conda env) +- New CLI command: `bioimageio add-weights` (and utility function: `bioimageio.core.add_weights`) +- removed `bioimageio.core.proc_ops.get_proc_class` in favor of `bioimageio.core.proc_ops.get_proc` + ### 0.7.0 - breaking: diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 613cd85d..c7554372 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -41,6 +41,7 @@ ) from ._settings import settings from .axis import Axis, AxisId +from .backends import create_model_adapter from .block_meta import BlockMeta from .common import MemberId from .prediction import predict, predict_many @@ -49,6 +50,7 @@ from .stat_measures import Stat from .tensor import Tensor from .utils import VERSION +from .weight_converters import add_weights __version__ = VERSION @@ -63,6 +65,7 @@ __all__ = [ "__version__", + "add_weights", "axis", "Axis", "AxisId", @@ -73,6 +76,7 @@ "commands", "common", "compute_dataset_measures", + "create_model_adapter", "create_prediction_pipeline", "digest_spec", "dump_description", diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 9da63bf5..436448f5 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -1,4 +1,4 @@ -from bioimageio.core.cli import Bioimageio +from .cli import Bioimageio def main(): diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 9f5ccc3e..0b7717aa 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -12,14 +12,21 @@ Union, ) +from loguru import logger from tqdm import tqdm from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 -from bioimageio.spec.model.v0_5 import WeightsFormat from ._op_base import BlockedOperator from .axis import AxisId, PerAxis -from .common import Halo, MemberId, PerMember, SampleId +from .common import ( + BlocksizeParameter, + Halo, + MemberId, + PerMember, + SampleId, + SupportedWeightsFormat, +) from .digest_spec import ( get_block_transform, get_input_halo, @@ -43,7 +50,8 @@ class PredictionPipeline: """ Represents model computation including preprocessing and postprocessing - Note: Ideally use the PredictionPipeline as a context manager + Note: Ideally use the `PredictionPipeline` in a with statement + (as a context manager). """ def __init__( @@ -54,13 +62,20 @@ def __init__( preprocessing: List[Processing], postprocessing: List[Processing], model_adapter: ModelAdapter, - default_ns: Union[ - v0_5.ParameterizedSize_N, - Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], - ] = 10, + default_ns: Optional[BlocksizeParameter] = None, + default_blocksize_parameter: BlocksizeParameter = 10, default_batch_size: int = 1, ) -> None: + """Use `create_prediction_pipeline` to create a `PredictionPipeline`""" super().__init__() + default_blocksize_parameter = default_ns or default_blocksize_parameter + if default_ns is not None: + warnings.warn( + "Argument `default_ns` is deprecated in favor of" + + " `default_blocksize_parameter` and will be removed soon."
+ ) + del default_ns + if model_description.run_mode: warnings.warn( f"Not yet implemented inference for run mode '{model_description.run_mode.name}'" @@ -88,7 +103,7 @@ def __init__( ) self._block_transform = get_block_transform(model_description) - self._default_ns = default_ns + self._default_blocksize_parameter = default_blocksize_parameter self._default_batch_size = default_batch_size self._input_ids = get_member_ids(model_description.inputs) @@ -121,19 +136,9 @@ def predict_sample_block( self.apply_preprocessing(sample_block) output_meta = sample_block.get_transformed_meta(self._block_transform) - output = output_meta.with_data( - { - tid: out - for tid, out in zip( - self._output_ids, - self._adapter.forward( - *(sample_block.members.get(t) for t in self._input_ids) - ), - ) - if out is not None - }, - stat=sample_block.stat, - ) + local_output = self._adapter.forward(sample_block) + + output = output_meta.with_data(local_output.members, stat=local_output.stat) if not skip_postprocessing: self.apply_postprocessing(output) @@ -152,20 +157,7 @@ def predict_sample_without_blocking( if not skip_preprocessing: self.apply_preprocessing(sample) - output = Sample( - members={ - out_id: out - for out_id, out in zip( - self._output_ids, - self._adapter.forward( - *(sample.members.get(in_id) for in_id in self._input_ids) - ), - ) - if out is not None - }, - stat=sample.stat, - id=sample.id, - ) + output = self._adapter.forward(sample) if not skip_postprocessing: self.apply_postprocessing(output) @@ -197,9 +189,15 @@ def predict_sample_with_fixed_blocking( ) input_blocks = list(input_blocks) predicted_blocks: List[SampleBlock] = [] + logger.info( + "split sample shape {} into {} blocks of {}.", + {k: dict(v) for k, v in sample.shape.items()}, + n_blocks, + {k: dict(v) for k, v in input_block_shape.items()}, + ) for b in tqdm( input_blocks, - desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}", + desc=f"predict {sample.id or ''} with {self.model_description.id or self.model_description.name}", unit="block", unit_divisor=1, total=n_blocks, @@ -238,7 +236,7 @@ def predict_sample_with_blocking( + " Consider using `predict_sample_with_fixed_blocking`" ) - ns = ns or self._default_ns + ns = ns or self._default_blocksize_parameter if isinstance(ns, int): ns = { (ipt.id, a.id): ns @@ -319,18 +317,16 @@ def create_prediction_pipeline( bioimageio_model: AnyModelDescr, *, devices: Optional[Sequence[str]] = None, - weight_format: Optional[WeightsFormat] = None, - weights_format: Optional[WeightsFormat] = None, + weight_format: Optional[SupportedWeightsFormat] = None, + weights_format: Optional[SupportedWeightsFormat] = None, dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(), keep_updating_initial_dataset_statistics: bool = False, fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType( {} ), model_adapter: Optional[ModelAdapter] = None, - ns: Union[ - v0_5.ParameterizedSize_N, - Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], - ] = 10, + ns: Optional[BlocksizeParameter] = None, + default_blocksize_parameter: BlocksizeParameter = 10, **deprecated_kwargs: Any, ) -> PredictionPipeline: """ @@ -340,9 +336,33 @@ def create_prediction_pipeline( * model prediction * computation of output statistics * postprocessing + + Args: + bioimageio_model: A bioimageio model description. 
+ devices: (optional) Devices to run inference on, e.g. 'cpu', 'cuda'. + weight_format: deprecated in favor of **weights_format** + weights_format: (optional) Use a specific **weights_format** rather than + choosing one automatically. + A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be + created to run inference with the **bioimageio_model**. + dataset_for_initial_statistics: (optional) If preprocessing steps require input + dataset statistics, **dataset_for_initial_statistics** allows you to + specify a dataset from which these statistics are computed. + keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want + to update dataset statistics with each processed sample. + fixed_dataset_statistics: (optional) Allows you to specify a mapping of + `DatasetMeasure`s to precomputed `MeasureValue`s. + model_adapter: (optional) Allows you to use a custom **model_adapter** instead + of creating one according to the present/selected **weights_format**. + ns: deprecated in favor of **default_blocksize_parameter** + default_blocksize_parameter: Allows you to control the default block size for + blockwise predictions, see `BlocksizeParameter`. + """ weights_format = weight_format or weights_format del weight_format + default_blocksize_parameter = ns or default_blocksize_parameter + del ns if deprecated_kwargs: warnings.warn( f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}" @@ -377,5 +397,5 @@ def dataset(): model_adapter=model_adapter, preprocessing=preprocessing, postprocessing=postprocessing, - default_ns=ns, + default_blocksize_parameter=default_blocksize_parameter, ) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 6ace6d5c..327e540a 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -1,21 +1,53 @@ -import traceback +import hashlib +import os +import platform +import subprocess import warnings +from io import StringIO from itertools import product -from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import ( + Callable, + Dict, + Hashable, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Union, + overload, +) -import numpy as np from loguru import logger +from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args from bioimageio.spec import ( + BioimageioCondaEnv, InvalidDescr, + LatestResourceDescr, ResourceDescr, + ValidationContext, build_description, dump_description, + get_conda_env, load_description, + save_bioimageio_package, ) +from bioimageio.spec._description_impl import DISCOVER from bioimageio.spec._internal.common_nodes import ResourceDescrBase -from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource -from bioimageio.spec.get_conda_env import get_conda_env +from bioimageio.spec._internal.io import is_yaml_value +from bioimageio.spec._internal.io_utils import read_yaml, write_yaml +from bioimageio.spec._internal.types import ( + AbsoluteTolerance, + FormatVersionPlaceholder, + MismatchedElementsPerMillion, + RelativeTolerance, +) +from bioimageio.spec._internal.validation_context import get_validation_context +from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat from bioimageio.spec.summary import ( @@ -27,19 +59,38 @@ from ._prediction_pipeline import create_prediction_pipeline from
.axis import AxisId, BatchSize +from .common import MemberId, SupportedWeightsFormat from .digest_spec import get_test_inputs, get_test_outputs from .sample import Sample from .utils import VERSION -def enable_determinism(mode: Literal["seed_only", "full"]): +class DeprecatedKwargs(TypedDict): + absolute_tolerance: NotRequired[AbsoluteTolerance] + relative_tolerance: NotRequired[RelativeTolerance] + decimal: NotRequired[Optional[int]] + + +def enable_determinism( + mode: Literal["seed_only", "full"] = "full", + weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None, +): """Seed and configure ML frameworks for maximum reproducibility. May degrade performance. Only recommended for testing reproducibility! Seed any random generators and (if **mode**=="full") request ML frameworks to use deterministic algorithms. + + Args: + mode: determinism mode + - 'seed_only' -- only set seeds, or + - 'full' -- all determinism features (might degrade performance or throw exceptions) + weight_formats: Limit importing of deep learning frameworks + based on weight_formats. + E.g. this allows you to avoid importing tensorflow when testing with pytorch. + Notes: - - **mode** == "full" might degrade performance and throw exceptions. + - **mode** == "full" might degrade performance or throw exceptions. - Subsequent inference calls might still differ. Call before each function (sequence) that is expected to be reproducible. - Degraded performance: Use for testing reproducibility only! @@ -58,120 +109,389 @@ def enable_determinism(mode: Literal["seed_only", "full"]): except Exception as e: logger.debug(str(e)) - try: + if ( + weight_formats is None + or "pytorch_state_dict" in weight_formats + or "torchscript" in weight_formats + ): try: - import torch - except ImportError: - pass - else: - _ = torch.manual_seed(0) - torch.use_deterministic_algorithms(mode == "full") - except Exception as e: - logger.debug(str(e)) + try: + import torch + except ImportError: + pass + else: + _ = torch.manual_seed(0) + torch.use_deterministic_algorithms(mode == "full") + except Exception as e: + logger.debug(str(e)) - try: + if ( + weight_formats is None + or "tensorflow_saved_model_bundle" in weight_formats + or "keras_hdf5" in weight_formats + ): try: - import keras - except ImportError: - pass - else: - keras.utils.set_random_seed(0) - except Exception as e: - logger.debug(str(e)) - - try: + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + try: + import tensorflow as tf # pyright: ignore[reportMissingTypeStubs] + except ImportError: + pass + else: + tf.random.set_seed(0) + if mode == "full": + tf.config.experimental.enable_op_determinism() + # TODO: find possibility to switch it off again?? + except Exception as e: + logger.debug(str(e)) + + if weight_formats is None or "keras_hdf5" in weight_formats: try: - import tensorflow as tf # pyright: ignore[reportMissingImports] - except ImportError: - pass - else: - tf.random.seed(0) - if mode == "full": - tf.config.experimental.enable_op_determinism() - # TODO: find possibility to switch it off again??
- except Exception as e: - logger.debug(str(e)) + try: + import keras # pyright: ignore[reportMissingTypeStubs] + except ImportError: + pass + else: + keras.utils.set_random_seed(0) + except Exception as e: + logger.debug(str(e)) def test_model( - source: Union[v0_5.ModelDescr, PermissiveFileSource], - weight_format: Optional[WeightsFormat] = None, + source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource], + weight_format: Optional[SupportedWeightsFormat] = None, devices: Optional[List[str]] = None, - absolute_tolerance: float = 1.5e-4, - relative_tolerance: float = 1e-4, - decimal: Optional[int] = None, *, determinism: Literal["seed_only", "full"] = "seed_only", + sha256: Optional[Sha256] = None, + stop_early: bool = False, + **deprecated: Unpack[DeprecatedKwargs], ) -> ValidationSummary: """Test model inference""" return test_description( source, weight_format=weight_format, devices=devices, - absolute_tolerance=absolute_tolerance, - relative_tolerance=relative_tolerance, - decimal=decimal, determinism=determinism, expected_type="model", + sha256=sha256, + stop_early=stop_early, + **deprecated, ) +def default_run_command(args: Sequence[str]): + logger.info("running '{}'...", " ".join(args)) + _ = subprocess.run(args, shell=True, text=True, check=True) + + def test_description( source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent], *, - format_version: Union[Literal["discover", "latest"], str] = "discover", - weight_format: Optional[WeightsFormat] = None, + format_version: Union[FormatVersionPlaceholder, str] = "discover", + weight_format: Optional[SupportedWeightsFormat] = None, devices: Optional[Sequence[str]] = None, - absolute_tolerance: float = 1.5e-4, - relative_tolerance: float = 1e-4, - decimal: Optional[int] = None, determinism: Literal["seed_only", "full"] = "seed_only", expected_type: Optional[str] = None, + sha256: Optional[Sha256] = None, + stop_early: bool = False, + runtime_env: Union[ + Literal["currently-active", "as-described"], Path, BioimageioCondaEnv + ] = ("currently-active"), + run_command: Callable[[Sequence[str]], None] = default_run_command, + **deprecated: Unpack[DeprecatedKwargs], ) -> ValidationSummary: - """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models""" - rd = load_description_and_test( - source, - format_version=format_version, - weight_format=weight_format, - devices=devices, - absolute_tolerance=absolute_tolerance, - relative_tolerance=relative_tolerance, - decimal=decimal, - determinism=determinism, - expected_type=expected_type, + """Test a bioimage.io resource dynamically, + for example run prediction of test tensors for models. + + Args: + source: model description source. + weight_format: Weight format to test. + Default: All weight formats present in **source**. + devices: Devices to test with, e.g. 'cpu', 'cuda'. + Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise. + determinism: Modes to improve reproducibility of test outputs. + expected_type: Assert an expected resource description `type`. + sha256: Expected SHA256 value of **source**. + (Ignored if **source** already is a loaded `ResourceDescr` object.) + stop_early: Do not run further subtests after a failed one. + runtime_env: (Experimental feature!) The Python environment to run the tests in + - `"currently-active"`: Use active Python interpreter. + - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda + environment YAML file based on the model weights description. 
+ - A `BioimageioCondaEnv` or a path to a conda environment YAML file. + Note: The `bioimageio.core` dependency will be added automatically if not present. + run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess + (ignored if **runtime_env** is `"currently-active"`). + """ + if runtime_env == "currently-active": + rd = load_description_and_test( + source, + format_version=format_version, + weight_format=weight_format, + devices=devices, + determinism=determinism, + expected_type=expected_type, + sha256=sha256, + stop_early=stop_early, + **deprecated, + ) + return rd.validation_summary + + if runtime_env == "as-described": + conda_env = None + elif isinstance(runtime_env, (str, Path)): + conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env))) + elif isinstance(runtime_env, BioimageioCondaEnv): + conda_env = runtime_env + else: + assert_never(runtime_env) + + with TemporaryDirectory(ignore_cleanup_errors=True) as _d: + working_dir = Path(_d) + if isinstance(source, (dict, ResourceDescrBase)): + file_source = save_bioimageio_package( + source, output_path=working_dir / "package.zip" + ) + else: + file_source = source + + return _test_in_env( + file_source, + working_dir=working_dir, + weight_format=weight_format, + conda_env=conda_env, + devices=devices, + determinism=determinism, + expected_type=expected_type, + sha256=sha256, + stop_early=stop_early, + run_command=run_command, + **deprecated, + ) + + +def _test_in_env( + source: PermissiveFileSource, + *, + working_dir: Path, + weight_format: Optional[SupportedWeightsFormat], + conda_env: Optional[BioimageioCondaEnv], + devices: Optional[Sequence[str]], + determinism: Literal["seed_only", "full"], + run_command: Callable[[Sequence[str]], None], + stop_early: bool, + expected_type: Optional[str], + sha256: Optional[Sha256], + **deprecated: Unpack[DeprecatedKwargs], +) -> ValidationSummary: + descr = load_description(source) + + if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise NotImplementedError("Not yet implemented for non-model resources") + + if weight_format is None: + all_present_wfs = [ + wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf) + ] + ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]] + logger.info( + "Found weight formats {}. 
Start testing all{}...", + all_present_wfs, + f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "", + ) + summary = _test_in_env( + source, + working_dir=working_dir / all_present_wfs[0], + weight_format=all_present_wfs[0], + devices=devices, + determinism=determinism, + conda_env=conda_env, + run_command=run_command, + expected_type=expected_type, + sha256=sha256, + stop_early=stop_early, + **deprecated, + ) + for wf in all_present_wfs[1:]: + additional_summary = _test_in_env( + source, + working_dir=working_dir / wf, + weight_format=wf, + devices=devices, + determinism=determinism, + conda_env=conda_env, + run_command=run_command, + expected_type=expected_type, + sha256=sha256, + stop_early=stop_early, + **deprecated, + ) + for d in additional_summary.details: + # TODO: filter redundant details; group details + summary.add_detail(d) + return summary + + if weight_format == "pytorch_state_dict": + wf = descr.weights.pytorch_state_dict + elif weight_format == "torchscript": + wf = descr.weights.torchscript + elif weight_format == "keras_hdf5": + wf = descr.weights.keras_hdf5 + elif weight_format == "onnx": + wf = descr.weights.onnx + elif weight_format == "tensorflow_saved_model_bundle": + wf = descr.weights.tensorflow_saved_model_bundle + elif weight_format == "tensorflow_js": + raise RuntimeError( + "testing 'tensorflow_js' is not supported by bioimageio.core" + ) + else: + assert_never(weight_format) + + assert wf is not None + if conda_env is None: + conda_env = get_conda_env(entry=wf) + + # remove name as we create a name based on the env description hash value + conda_env.name = None + + dumped_env = conda_env.model_dump(mode="json", exclude_none=True) + if not is_yaml_value(dumped_env): + raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}") + + env_io = StringIO() + write_yaml(dumped_env, file=env_io) + encoded_env = env_io.getvalue().encode() + env_name = hashlib.sha256(encoded_env).hexdigest() + + try: + run_command(["where" if platform.system() == "Windows" else "which", "conda"]) + except Exception as e: + raise RuntimeError("Conda not available") from e + + working_dir.mkdir(parents=True, exist_ok=True) + try: + run_command(["conda", "activate", env_name]) + except Exception: + path = working_dir / "env.yaml" + _ = path.write_bytes(encoded_env) + logger.debug("written conda env to {}", path) + run_command(["conda", "env", "create", f"--file={path}", f"--name={env_name}"]) + run_command(["conda", "activate", env_name]) + + summary_path = working_dir / "summary.json" + run_command( [ "conda", "run", "-n", env_name, "bioimageio", "test", str(source), f"--summary-path={summary_path}", f"--determinism={determinism}", ] + ([f"--expected-type={expected_type}"] if expected_type else []) + (["--stop-early"] if stop_early else []) ) - return rd.validation_summary + return ValidationSummary.model_validate_json(summary_path.read_bytes()) +@overload def load_description_and_test( source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent], *, - format_version: Union[Literal["discover", "latest"], str] = "discover", - weight_format: Optional[WeightsFormat] = None, + format_version: Literal["latest"], + weight_format: Optional[SupportedWeightsFormat] = None, devices: Optional[Sequence[str]] = None, - absolute_tolerance: float = 1.5e-4, - relative_tolerance: float = 1e-4, - decimal: Optional[int] = None, determinism: Literal["seed_only", "full"] = "seed_only", expected_type: Optional[str] = None, + sha256: Optional[Sha256] =
None, + stop_early: bool = False, + **deprecated: Unpack[DeprecatedKwargs], +) -> Union[LatestResourceDescr, InvalidDescr]: ... + + +@overload +def load_description_and_test( + source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent], + *, + format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, + weight_format: Optional[SupportedWeightsFormat] = None, + devices: Optional[Sequence[str]] = None, + determinism: Literal["seed_only", "full"] = "seed_only", + expected_type: Optional[str] = None, + sha256: Optional[Sha256] = None, + stop_early: bool = False, + **deprecated: Unpack[DeprecatedKwargs], +) -> Union[ResourceDescr, InvalidDescr]: ... + + +def load_description_and_test( + source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent], + *, + format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, + weight_format: Optional[SupportedWeightsFormat] = None, + devices: Optional[Sequence[str]] = None, + determinism: Literal["seed_only", "full"] = "seed_only", + expected_type: Optional[str] = None, + sha256: Optional[Sha256] = None, + stop_early: bool = False, + **deprecated: Unpack[DeprecatedKwargs], ) -> Union[ResourceDescr, InvalidDescr]: - """Test RDF dynamically, e.g. model inference of test inputs""" - if ( - isinstance(source, ResourceDescrBase) - and format_version != "discover" - and source.format_version != format_version - ): - warnings.warn( - f"deserializing source to ensure we validate and test using format {format_version}" - ) - source = dump_description(source) + """Test a bioimage.io resource dynamically, + for example run prediction of test tensors for models. + + See `test_description` for more details. + + Returns: + A (possibly invalid) resource description object + with a populated `.validation_summary` attribute. 
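+
+    Example (a minimal sketch; `"my_model.zip"` stands in for any model source,
+    e.g. a local path or URL):
+
+    ```python
+    from bioimageio.core import load_description_and_test
+
+    rd = load_description_and_test("my_model.zip", stop_early=True)
+    print(rd.validation_summary.status)  # "passed" or "failed"
+    ```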
+ """ + if isinstance(source, ResourceDescrBase): + root = source.root + file_name = source.file_name + if ( + ( + format_version + not in ( + DISCOVER, + source.format_version, + ".".join(source.format_version.split(".")[:2]), + ) + ) + or (c := source.validation_summary.details[0].context) is None + or not c.perform_io_checks + ): + logger.debug( + "deserializing source to ensure we validate and test using format {} and perform io checks", + format_version, + ) + source = dump_description(source) + else: + root = Path() + file_name = None if isinstance(source, ResourceDescrBase): rd = source elif isinstance(source, dict): - rd = build_description(source, format_version=format_version) + # check context for a given root; default to root of source + context = get_validation_context( + ValidationContext(root=root, file_name=file_name) + ).replace( + perform_io_checks=True # make sure we perform io checks though + ) + + rd = build_description( + source, + format_version=format_version, + context=context, + ) else: - rd = load_description(source, format_version=format_version) + rd = load_description( + source, format_version=format_version, sha256=sha256, perform_io_checks=True + ) rd.validation_summary.env.add( InstalledPackage(name="bioimageio.core", version=VERSION) @@ -182,50 +502,106 @@ def load_description_and_test( if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): if weight_format is None: - weight_formats: List[WeightsFormat] = [ + weight_formats: List[SupportedWeightsFormat] = [ w for w, we in rd.weights if we is not None ] # pyright: ignore[reportAssignmentType] else: weight_formats = [weight_format] - if decimal is None: - atol = absolute_tolerance - rtol = relative_tolerance - else: - warnings.warn( - "The argument `decimal` has been deprecated in favour of" - + " `relative_tolerance` and `absolute_tolerance`, with different" - + " validation logic, using `numpy.testing.assert_allclose, see" - + " 'https://numpy.org/doc/stable/reference/generated/" - + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`" - + " will cause validation to revert to the old behaviour." 
- ) - atol = 1.5 * 10 ** (-decimal) - rtol = 0 - - enable_determinism(determinism) + enable_determinism(determinism, weight_formats=weight_formats) for w in weight_formats: - _test_model_inference(rd, w, devices, atol, rtol) + _test_model_inference(rd, w, devices, **deprecated) + if stop_early and rd.validation_summary.status == "failed": + break + if not isinstance(rd, v0_4.ModelDescr): - _test_model_inference_parametrized(rd, w, devices) + _test_model_inference_parametrized( + rd, w, devices, stop_early=stop_early + ) + if stop_early and rd.validation_summary.status == "failed": + break # TODO: add execution of jupyter notebooks # TODO: add more tests + if rd.validation_summary.status == "valid-format": + rd.validation_summary.status = "passed" + return rd +def _get_tolerance( + model: Union[v0_4.ModelDescr, v0_5.ModelDescr], + wf: SupportedWeightsFormat, + m: MemberId, + **deprecated: Unpack[DeprecatedKwargs], +) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]: + if isinstance(model, v0_5.ModelDescr): + applicable = v0_5.ReproducibilityTolerance() + + # check legacy test kwargs for weight format specific tolerance + if model.config.bioimageio.model_extra is not None: + for weights_format, test_kwargs in model.config.bioimageio.model_extra.get( + "test_kwargs", {} + ).items(): + if wf == weights_format: + applicable = v0_5.ReproducibilityTolerance( + relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3), + absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4), + ) + break + + # check for weights format and output tensor specific tolerance + for a in model.config.bioimageio.reproducibility_tolerance: + if (not a.weights_formats or wf in a.weights_formats) and ( + not a.output_ids or m in a.output_ids + ): + applicable = a + break + + rtol = applicable.relative_tolerance + atol = applicable.absolute_tolerance + mismatched_tol = applicable.mismatched_elements_per_million + elif (decimal := deprecated.get("decimal")) is not None: + warnings.warn( + "The argument `decimal` has been deprecated in favour of" + + " `relative_tolerance` and `absolute_tolerance`, with different" + + " validation logic, using `numpy.testing.assert_allclose, see" + + " 'https://numpy.org/doc/stable/reference/generated/" + + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`" + + " will cause validation to revert to the old behaviour." 
+ ) + atol = 1.5 * 10 ** (-decimal) + rtol = 0 + mismatched_tol = 0 + else: + # use given (deprecated) test kwargs + atol = deprecated.get("absolute_tolerance", 1e-5) + rtol = deprecated.get("relative_tolerance", 1e-3) + mismatched_tol = 0 + + return rtol, atol, mismatched_tol + + def _test_model_inference( model: Union[v0_4.ModelDescr, v0_5.ModelDescr], - weight_format: WeightsFormat, + weight_format: SupportedWeightsFormat, devices: Optional[Sequence[str]], - atol: float, - rtol: float, + **deprecated: Unpack[DeprecatedKwargs], ) -> None: test_name = f"Reproduce test outputs from test inputs ({weight_format})" - logger.info("starting '{}'", test_name) - error: Optional[str] = None - tb: List[str] = [] + logger.debug("starting '{}'", test_name) + errors: List[ErrorEntry] = [] + + def add_error_entry(msg: str, with_traceback: bool = False): + errors.append( + ErrorEntry( + loc=("weights", weight_format), + msg=msg, + type="bioimageio.core", + with_traceback=with_traceback, + ) + ) try: inputs = get_test_inputs(model) @@ -237,54 +613,66 @@ def _test_model_inference( results = prediction_pipeline.predict_sample_without_blocking(inputs) if len(results.members) != len(expected.members): - error = f"Expected {len(expected.members)} outputs, but got {len(results.members)}" + add_error_entry( + f"Expected {len(expected.members)} outputs, but got {len(results.members)}" + ) else: - for m, exp in expected.members.items(): - res = results.members.get(m) - if res is None: - error = "Output tensors for test case may not be None" + for m, expected in expected.members.items(): + actual = results.members.get(m) + if actual is None: + add_error_entry("Output tensors for test case may not be None") break - try: - np.testing.assert_allclose( - res.data, - exp.data, - rtol=rtol, - atol=atol, + + rtol, atol, mismatched_tol = _get_tolerance( + model, wf=weight_format, m=m, **deprecated + ) + mismatched = (abs_diff := abs(actual - expected)) > atol + rtol * abs( + expected + ) + mismatched_elements = mismatched.sum().item() + if mismatched_elements / expected.size > mismatched_tol / 1e6: + r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax() + r_max = r_diff[r_max_idx].item() + r_actual = actual[r_max_idx].item() + r_expected = expected[r_max_idx].item() + a_max_idx = abs_diff.argmax() + a_max = abs_diff[a_max_idx].item() + a_actual = actual[a_max_idx].item() + a_expected = expected[a_max_idx].item() + add_error_entry( + f"Output '{m}' disagrees with {mismatched_elements} of" + + f" {expected.size} expected values." 
+ + f"\n Max relative difference: {r_max:.2e}" + + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" + + f" at {r_max_idx}" + + f"\n Max absolute difference: {a_max:.2e}" + + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}" ) - except AssertionError as e: - error = f"Output and expected output disagree:\n {e}" break except Exception as e: - error = str(e) - tb = traceback.format_tb(e.__traceback__) + if get_validation_context().raise_errors: + raise e + + add_error_entry(str(e), with_traceback=True) model.validation_summary.add_detail( ValidationDetail( name=test_name, loc=("weights", weight_format), - status="passed" if error is None else "failed", + status="failed" if errors else "passed", recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]), - errors=( - [] - if error is None - else [ - ErrorEntry( - loc=("weights", weight_format), - msg=error, - type="bioimageio.core", - traceback=tb, - ) - ] - ), + errors=errors, ) ) def _test_model_inference_parametrized( model: v0_5.ModelDescr, - weight_format: WeightsFormat, + weight_format: SupportedWeightsFormat, devices: Optional[Sequence[str]], + *, + stop_early: bool, ) -> None: if not any( isinstance(a.size, v0_5.ParameterizedSize) @@ -311,11 +699,13 @@ def _test_model_inference_parametrized( # no batch axis batch_sizes = {1} - test_cases: Set[Tuple[v0_5.ParameterizedSize_N, BatchSize]] = { - (n, b) for n, b in product(sorted(ns), sorted(batch_sizes)) + test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = { + (b, n) for b, n in product(sorted(batch_sizes), sorted(ns)) } logger.info( - "Testing inference with {} different input tensor sizes", len(test_cases) + "Testing inference with {} different inputs (B, N): {}", + len(test_cases), + test_cases, ) def generate_test_cases(): @@ -329,7 +719,7 @@ def get_ns(n: int): if isinstance(a.size, v0_5.ParameterizedSize) } - for n, batch_size in sorted(test_cases): + for batch_size, n in sorted(test_cases): input_target_sizes, expected_output_sizes = model.get_axis_sizes( get_ns(n), batch_size=batch_size ) @@ -343,12 +733,14 @@ def get_ns(n: int): resized_test_inputs = Sample( members={ - t.id: test_inputs.members[t.id].resize_to( - { - aid: s - for (tid, aid), s in input_target_sizes.items() - if tid == t.id - }, + t.id: ( + test_inputs.members[t.id].resize_to( + { + aid: s + for (tid, aid), s in input_target_sizes.items() + if tid == t.id + }, + ) ) for t in model.inputs }, @@ -422,9 +814,12 @@ def get_ns(n: int): ), ) ) + if stop_early and error is not None: + break except Exception as e: - error = str(e) - tb = traceback.format_tb(e.__traceback__) + if get_validation_context().raise_errors: + raise e + model.validation_summary.add_detail( ValidationDetail( name=f"Run {weight_format} inference for parametrized inputs", @@ -433,9 +828,9 @@ def get_ns(n: int): errors=[ ErrorEntry( loc=("weights", weight_format), - msg=error, + msg=str(e), type="bioimageio.core", - traceback=tb, + with_traceback=True, ) ], ) @@ -458,7 +853,7 @@ def _test_expected_resource_type( ErrorEntry( loc=("type",), type="type", - msg=f"expected type {expected_type}, found {rd.type}", + msg=f"Expected type {expected_type}, found {rd.type}", ) ] ), diff --git a/bioimageio/core/axis.py b/bioimageio/core/axis.py index 34dfa3e1..0b39045e 100644 --- a/bioimageio/core/axis.py +++ b/bioimageio/core/axis.py @@ -8,25 +8,33 @@ from bioimageio.spec.model import v0_5 -def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]): - if a == "b": +def _guess_axis_type(a: str): 
+ if a in ("b", "batch"): return "batch" - elif a == "t": + elif a in ("t", "time"): return "time" - elif a == "i": + elif a in ("i", "index"): return "index" - elif a == "c": + elif a in ("c", "channel"): return "channel" elif a in ("x", "y", "z"): return "space" else: - return "index" # return most unspecific axis + raise ValueError( + f"Failed to infer axis type for axis id '{a}'." + + " Consider using one of: '" + + "', '".join( + ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"] + ) + + "'. Or creating an `Axis` object instead." + ) S = TypeVar("S", bound=str) AxisId = v0_5.AxisId +"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'""" T = TypeVar("T") PerAxis = Mapping[AxisId, T] @@ -42,16 +50,22 @@ class Axis: id: AxisId type: Literal["batch", "channel", "index", "space", "time"] + def __post_init__(self): + if self.type == "batch": + self.id = AxisId("batch") + elif self.type == "channel": + self.id = AxisId("channel") + @classmethod def create(cls, axis: AxisLike) -> Axis: if isinstance(axis, cls): return axis elif isinstance(axis, Axis): return Axis(id=axis.id, type=axis.type) - elif isinstance(axis, str): - return Axis(id=AxisId(axis), type=_get_axis_type(axis)) elif isinstance(axis, v0_5.AxisBase): return Axis(id=AxisId(axis.id), type=axis.type) + elif isinstance(axis, str): + return Axis(id=AxisId(axis), type=_guess_axis_type(axis)) else: assert_never(axis) diff --git a/bioimageio/core/backends/__init__.py b/bioimageio/core/backends/__init__.py new file mode 100644 index 00000000..c39b58b5 --- /dev/null +++ b/bioimageio/core/backends/__init__.py @@ -0,0 +1,3 @@ +from ._model_adapter import create_model_adapter + +__all__ = ["create_model_adapter"] diff --git a/bioimageio/core/backends/_model_adapter.py b/bioimageio/core/backends/_model_adapter.py new file mode 100644 index 00000000..db4d44e9 --- /dev/null +++ b/bioimageio/core/backends/_model_adapter.py @@ -0,0 +1,245 @@ +import sys +import warnings +from abc import ABC, abstractmethod +from typing import ( + Any, + List, + Optional, + Sequence, + Tuple, + Union, + final, +) + +from numpy.typing import NDArray +from typing_extensions import assert_never + +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 + +from ..common import SupportedWeightsFormat +from ..digest_spec import get_axes_infos, get_member_ids +from ..sample import Sample, SampleBlock, SampleBlockWithOrigin +from ..tensor import Tensor + +# Known weight formats in order of priority +# First match wins +DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = ( + "pytorch_state_dict", + "tensorflow_saved_model_bundle", + "torchscript", + "onnx", + "keras_hdf5", +) + + +class ModelAdapter(ABC): + """ + Represents model *without* any preprocessing or postprocessing. + + ``` + from bioimageio.core import load_description + + model = load_description(...) + + # option 1: + adapter = ModelAdapter.create(model) + adapter.forward(...) + adapter.unload() + + # option 2: + with ModelAdapter.create(model) as adapter: + adapter.forward(...) 
+ ``` + """ + + def __init__(self, model_description: AnyModelDescr): + super().__init__() + self._model_descr = model_description + self._input_ids = get_member_ids(model_description.inputs) + self._output_ids = get_member_ids(model_description.outputs) + self._input_axes = [ + tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs + ] + self._output_axes = [ + tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs + ] + if isinstance(model_description, v0_4.ModelDescr): + self._input_is_optional = [False] * len(model_description.inputs) + else: + self._input_is_optional = [ipt.optional for ipt in model_description.inputs] + + @final + @classmethod + def create( + cls, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + *, + devices: Optional[Sequence[str]] = None, + weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None, + ): + """ + Creates model adapter based on the passed spec + Note: All specific adapters should happen inside this function to prevent different framework + initializations interfering with each other + """ + if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise TypeError( + f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}" + ) + + weights = model_description.weights + errors: List[Exception] = [] + weight_format_priority_order = ( + DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER + if weight_format_priority_order is None + else weight_format_priority_order + ) + # limit weight formats to the ones present + weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [ + w for w in weight_format_priority_order if getattr(weights, w) is not None + ] + if not weight_format_priority_order_present: + raise ValueError( + f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})" + ) + + for wf in weight_format_priority_order_present: + if wf == "pytorch_state_dict": + assert weights.pytorch_state_dict is not None + try: + from .pytorch_backend import PytorchModelAdapter + + return PytorchModelAdapter( + model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(e) + elif wf == "tensorflow_saved_model_bundle": + assert weights.tensorflow_saved_model_bundle is not None + try: + from .tensorflow_backend import create_tf_model_adapter + + return create_tf_model_adapter( + model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(e) + elif wf == "onnx": + assert weights.onnx is not None + try: + from .onnx_backend import ONNXModelAdapter + + return ONNXModelAdapter( + model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(e) + elif wf == "torchscript": + assert weights.torchscript is not None + try: + from .torchscript_backend import TorchscriptModelAdapter + + return TorchscriptModelAdapter( + model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(e) + elif wf == "keras_hdf5": + assert weights.keras_hdf5 is not None + # keras can either be installed as a separate package or used as part of tensorflow + # we try to first import the keras model adapter using the separate package and, + # if it is not available, try to load the one using tf + try: + try: + from .keras_backend import KerasModelAdapter + except Exception: + from .tensorflow_backend import KerasModelAdapter + + return KerasModelAdapter( + 
model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(e) + else: + assert_never(wf) + + assert errors + if len(weight_format_priority_order) == 1: + assert len(errors) == 1 + raise errors[0] + + else: + msg = ( + "None of the weight format specific model adapters could be created" + + " in this environment." + ) + if sys.version_info[:2] >= (3, 11): + raise ExceptionGroup(msg, errors) + else: + raise ValueError(msg) from Exception(errors) + + @final + def load(self, *, devices: Optional[Sequence[str]] = None) -> None: + warnings.warn("Deprecated. ModelAdapter is loaded on initialization") + + def forward( + self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin] + ) -> Sample: + """ + Run forward pass of model to get model predictions + + Note: sample id and sample stat attributes are passed through + """ + unexpected = [mid for mid in input_sample.members if mid not in self._input_ids] + if unexpected: + warnings.warn(f"Got unexpected input tensor IDs: {unexpected}") + + input_arrays = [ + ( + None + if (a := input_sample.members.get(in_id)) is None + else a.transpose(in_order).data.data + ) + for in_id, in_order in zip(self._input_ids, self._input_axes) + ] + output_arrays = self._forward_impl(input_arrays) + assert len(output_arrays) <= len(self._output_ids) + output_tensors = [ + None if a is None else Tensor(a, dims=d) + for a, d in zip(output_arrays, self._output_axes) + ] + return Sample( + members={ + tid: out + for tid, out in zip( + self._output_ids, + output_tensors, + ) + if out is not None + }, + stat=input_sample.stat, + id=( + input_sample.id + if isinstance(input_sample, Sample) + else input_sample.sample_id + ), + ) + + @abstractmethod + def _forward_impl( + self, input_arrays: Sequence[Optional[NDArray[Any]]] + ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]: + """framework specific forward implementation""" + + @abstractmethod + def unload(self): + """ + Unload model from any devices, freeing their memory. + The model adapter should be considered unusable afterwards.
+ """ + + def _get_input_args_numpy(self, input_sample: Sample): + """helper to extract tensor args as transposed numpy arrays""" + + +create_model_adapter = ModelAdapter.create diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/backends/keras_backend.py similarity index 60% rename from bioimageio/core/model_adapters/_keras_model_adapter.py rename to bioimageio/core/backends/keras_backend.py index e6864ccc..1c10da7d 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/backends/keras_backend.py @@ -1,39 +1,33 @@ import os -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union from loguru import logger from numpy.typing import NDArray -from bioimageio.spec._internal.io_utils import download +from bioimageio.spec._internal.io import download +from bioimageio.spec._internal.type_guards import is_list, is_tuple from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import Version from .._settings import settings from ..digest_spec import get_axes_infos -from ..tensor import Tensor from ._model_adapter import ModelAdapter os.environ["KERAS_BACKEND"] = settings.keras_backend # by default, we use the keras integrated with tensorflow +# TODO: check if we should prefer keras try: - import tensorflow as tf # pyright: ignore[reportMissingImports] - from tensorflow import ( # pyright: ignore[reportMissingImports] - keras, # pyright: ignore[reportUnknownVariableType] + import tensorflow as tf # pyright: ignore[reportMissingTypeStubs] + from tensorflow import ( # pyright: ignore[reportMissingTypeStubs] + keras, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] ) - tf_version = Version(tf.__version__) # pyright: ignore[reportUnknownArgumentType] + tf_version = Version(tf.__version__) except Exception: - try: - import keras # pyright: ignore[reportMissingImports] - except Exception as e: - keras = None - keras_error = str(e) - else: - keras_error = None + import keras # pyright: ignore[reportMissingTypeStubs] + tf_version = None -else: - keras_error = None class KerasModelAdapter(ModelAdapter): @@ -43,10 +37,7 @@ def __init__( model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ) -> None: - if keras is None: - raise ImportError(f"failed to import keras: {keras_error}") - - super().__init__() + super().__init__(model_description=model_description) if model_description.weights.keras_hdf5 is None: raise ValueError("model has not keras_hdf5 weights specified") model_tf_version = model_description.weights.keras_hdf5.tensorflow_version @@ -84,22 +75,14 @@ def __init__( for out in model_description.outputs ] - def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: - _result: Union[Sequence[NDArray[Any]], NDArray[Any]] - _result = self._network.predict( # pyright: ignore[reportUnknownVariableType] - *[None if t is None else t.data.data for t in input_tensors] - ) - if isinstance(_result, (tuple, list)): - result: Sequence[NDArray[Any]] = _result + def _forward_impl( # pyright: ignore[reportUnknownParameterType] + self, input_arrays: Sequence[Optional[NDArray[Any]]] + ): + network_output = self._network.predict(*input_arrays) # type: ignore + if is_list(network_output) or is_tuple(network_output): + return network_output else: - result = [_result] # type: ignore - - assert len(result) == len(self._output_axes) - ret: List[Optional[Tensor]] = [] - ret.extend( - [Tensor(r, 
dims=axes) for r, axes, in zip(result, self._output_axes)] - ) - return ret + return [network_output] # pyright: ignore[reportUnknownVariableType] def unload(self) -> None: logger.warning( diff --git a/bioimageio/core/backends/onnx_backend.py b/bioimageio/core/backends/onnx_backend.py new file mode 100644 index 00000000..d5b89152 --- /dev/null +++ b/bioimageio/core/backends/onnx_backend.py @@ -0,0 +1,53 @@ +# pyright: reportUnknownVariableType=false +import warnings +from typing import Any, List, Optional, Sequence, Union + +import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] +from numpy.typing import NDArray + +from bioimageio.spec._internal.type_guards import is_list, is_tuple +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..model_adapters import ModelAdapter + + +class ONNXModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + super().__init__(model_description=model_description) + + if model_description.weights.onnx is None: + raise ValueError(f"No ONNX weights specified for {model_description.name}") + + local_path = download(model_description.weights.onnx.source).path + self._session = rt.InferenceSession(local_path.read_bytes()) + onnx_inputs = self._session.get_inputs() + self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] + + if devices is not None: + warnings.warn( + f"Device management is not implemented for onnx yet, ignoring the devices {devices}" + ) + + def _forward_impl( + self, input_arrays: Sequence[Optional[NDArray[Any]]] + ) -> List[Optional[NDArray[Any]]]: + result: Any = self._session.run( + None, dict(zip(self._input_names, input_arrays)) + ) + if is_list(result) or is_tuple(result): + result_seq = list(result) + else: + result_seq = [result] + + return result_seq + + def unload(self) -> None: + warnings.warn( + "Device management is not implemented for onnx yet, cannot unload model" + ) diff --git a/bioimageio/core/backends/pytorch_backend.py b/bioimageio/core/backends/pytorch_backend.py new file mode 100644 index 00000000..af1ea85d --- /dev/null +++ b/bioimageio/core/backends/pytorch_backend.py @@ -0,0 +1,180 @@ +import gc +import warnings +from contextlib import nullcontext +from io import TextIOWrapper +from pathlib import Path +from typing import Any, List, Literal, Optional, Sequence, Union + +import torch +from loguru import logger +from numpy.typing import NDArray +from torch import nn +from typing_extensions import assert_never + +from bioimageio.spec._internal.type_guards import is_list, is_ndarray, is_tuple +from bioimageio.spec.common import ZipPath +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..digest_spec import import_callable +from ._model_adapter import ModelAdapter + + +class PytorchModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: AnyModelDescr, + devices: Optional[Sequence[Union[str, torch.device]]] = None, + mode: Literal["eval", "train"] = "eval", + ): + super().__init__(model_description=model_description) + weights = model_description.weights.pytorch_state_dict + if weights is None: + raise ValueError("No `pytorch_state_dict` weights found") + + devices = get_devices(devices) + self._model = load_torch_model(weights, load_state=True, devices=devices) + if mode == "eval": + self._model = self._model.eval() + elif mode == "train": + self._model =
self._model.train() + else: + assert_never(mode) + + self._mode: Literal["eval", "train"] = mode + self._primary_device = devices[0] + + def _forward_impl( + self, input_arrays: Sequence[Optional[NDArray[Any]]] + ) -> List[Optional[NDArray[Any]]]: + tensors = [ + None if a is None else torch.from_numpy(a).to(self._primary_device) + for a in input_arrays + ] + + if self._mode == "eval": + ctxt = torch.no_grad + elif self._mode == "train": + ctxt = nullcontext + else: + assert_never(self._mode) + + with ctxt(): + model_out = self._model(*tensors) + + if is_tuple(model_out) or is_list(model_out): + model_out_seq = model_out + else: + model_out_seq = model_out = [model_out] + + result: List[Optional[NDArray[Any]]] = [] + for i, r in enumerate(model_out_seq): + if r is None: + result.append(None) + elif isinstance(r, torch.Tensor): + r_np: NDArray[Any] = r.detach().cpu().numpy() + result.append(r_np) + elif is_ndarray(r): + result.append(r) + else: + raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.") + + return result + + def unload(self) -> None: + del self._model + _ = gc.collect() # deallocate memory + assert torch is not None + torch.cuda.empty_cache() # release reserved memory + + +def load_torch_model( + weight_spec: Union[ + v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr + ], + *, + load_state: bool = True, + devices: Optional[Sequence[Union[str, torch.device]]] = None, +) -> nn.Module: + custom_callable = import_callable( + weight_spec.architecture, + sha256=( + weight_spec.architecture_sha256 + if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) + else weight_spec.sha256 + ), + ) + model_kwargs = ( + weight_spec.kwargs + if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) + else weight_spec.architecture.kwargs + ) + torch_model = custom_callable(**model_kwargs) + + if not isinstance(torch_model, nn.Module): + if isinstance( + weight_spec.architecture, + (v0_4.CallableFromFile, v0_4.CallableFromDepencency), + ): + callable_name = weight_spec.architecture.callable_name + else: + callable_name = weight_spec.architecture.callable + + raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.") + + if load_state or devices: + use_devices = get_devices(devices) + torch_model = torch_model.to(use_devices[0]) + if load_state: + torch_model = load_torch_state_dict( + torch_model, + path=download(weight_spec).path, + devices=use_devices, + ) + return torch_model + + +def load_torch_state_dict( + model: nn.Module, + path: Union[Path, ZipPath], + devices: Sequence[torch.device], +) -> nn.Module: + model = model.to(devices[0]) + with path.open("rb") as f: + assert not isinstance(f, TextIOWrapper) + state = torch.load(f, map_location=devices[0], weights_only=True) + + incompatible = model.load_state_dict(state) + if ( + incompatible is not None # pyright: ignore[reportUnnecessaryComparison] + and incompatible.missing_keys + ): + logger.warning("Missing state dict keys: {}", incompatible.missing_keys) + + if ( + incompatible is not None # pyright: ignore[reportUnnecessaryComparison] + and incompatible.unexpected_keys + ): + logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys) + + return model + + +def get_devices( + devices: Optional[Sequence[Union[torch.device, str]]] = None, +) -> List[torch.device]: + if not devices: + torch_devices = [ + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ] + else: + torch_devices = [torch.device(d) for d in devices] + + if 
len(torch_devices) > 1: + warnings.warn( + f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}" + ) + torch_devices = torch_devices[:1] + + return torch_devices diff --git a/bioimageio/core/backends/tensorflow_backend.py b/bioimageio/core/backends/tensorflow_backend.py new file mode 100644 index 00000000..99efe9ef --- /dev/null +++ b/bioimageio/core/backends/tensorflow_backend.py @@ -0,0 +1,212 @@ +from pathlib import Path +from typing import Any, Optional, Sequence, Union + +import numpy as np +import tensorflow as tf # pyright: ignore[reportMissingTypeStubs] +from loguru import logger +from numpy.typing import NDArray + +from bioimageio.core.io import ensure_unzipped +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 + +from ._model_adapter import ModelAdapter + + +class TensorflowModelAdapter(ModelAdapter): + weight_format = "tensorflow_saved_model_bundle" + + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + super().__init__(model_description=model_description) + + weight_file = model_description.weights.tensorflow_saved_model_bundle + if model_description.weights.tensorflow_saved_model_bundle is None: + raise ValueError("No `tensorflow_saved_model_bundle` weights found") + + if devices is not None: + logger.warning( + f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" + ) + + # TODO: check how to load tf weights without unzipping + weight_file = ensure_unzipped( + model_description.weights.tensorflow_saved_model_bundle.source, + Path("bioimageio_unzipped_tf_weights"), + ) + self._network = str(weight_file) + + # TODO: currently we reload the model every time. It would be better to keep the graph and session + # alive between forward passes (but then the sessions need to be properly opened / closed) + def _forward_impl( # pyright: ignore[reportUnknownParameterType] + self, input_arrays: Sequence[Optional[NDArray[Any]]] + ): + # TODO read from spec + tag = ( # pyright: ignore[reportUnknownVariableType] + tf.saved_model.tag_constants.SERVING # pyright: ignore[reportAttributeAccessIssue] + ) + signature_key = ( # pyright: ignore[reportUnknownVariableType] + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pyright: ignore[reportAttributeAccessIssue] + ) + + graph = tf.Graph() + with graph.as_default(): + with tf.Session( # pyright: ignore[reportAttributeAccessIssue] + graph=graph + ) as sess: # pyright: ignore[reportUnknownVariableType] + # load the model and the signature + graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] + sess, [tag], self._network + ) + signature = ( # pyright: ignore[reportUnknownVariableType] + graph_def.signature_def + ) + + # get the tensors into the graph + in_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].inputs[key].name for key in self._input_ids + ] + out_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].outputs[key].name + for key in self._output_ids + ] + in_tf_tensors = [ + graph.get_tensor_by_name( + name # pyright: ignore[reportUnknownArgumentType] + ) + for name in in_names # pyright: ignore[reportUnknownVariableType] + ] + out_tf_tensors = [ + graph.get_tensor_by_name( + name # pyright: ignore[reportUnknownArgumentType] + ) + for name in out_names # pyright: ignore[reportUnknownVariableType] + ] + + # run prediction + res
= sess.run( # pyright: ignore[reportUnknownVariableType] + dict( + zip( + out_names, # pyright: ignore[reportUnknownArgumentType] + out_tf_tensors, + ) + ), + dict(zip(in_tf_tensors, input_arrays)), + ) + # from dict to list of tensors + res = [ # pyright: ignore[reportUnknownVariableType] + res[out] + for out in out_names # pyright: ignore[reportUnknownVariableType] + ] + + return res # pyright: ignore[reportUnknownVariableType] + + def unload(self) -> None: + logger.warning( + "Device management is not implemented for tensorflow 1, cannot unload model" + ) + + +class KerasModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + if model_description.weights.tensorflow_saved_model_bundle is None: + raise ValueError("No `tensorflow_saved_model_bundle` weights found") + + super().__init__(model_description=model_description) + if devices is not None: + logger.warning( + f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" + ) + + # TODO: check how to load tf weights without unzipping + weight_file = ensure_unzipped( + model_description.weights.tensorflow_saved_model_bundle.source, + Path("bioimageio_unzipped_tf_weights"), + ) + + try: + self._network = tf.keras.layers.TFSMLayer( # pyright: ignore[reportAttributeAccessIssue] + weight_file, + call_endpoint="serve", + ) + except Exception as e: + try: + self._network = tf.keras.layers.TFSMLayer( # pyright: ignore[reportAttributeAccessIssue] + weight_file, call_endpoint="serving_default" + ) + except Exception as ee: + logger.opt(exception=ee).info( + "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'" + ) + raise e + + def _forward_impl( # pyright: ignore[reportUnknownParameterType] + self, input_arrays: Sequence[Optional[NDArray[Any]]] + ): + assert tf is not None + tf_tensor = [ + None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays + ] + + result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType] + + assert isinstance(result, dict) + + # TODO: Use RDF's `outputs[i].id` here + result = list( # pyright: ignore[reportUnknownVariableType] + result.values() # pyright: ignore[reportUnknownArgumentType] + ) + + return [ # pyright: ignore[reportUnknownVariableType] + (None if r is None else r if isinstance(r, np.ndarray) else r.numpy()) + for r in result # pyright: ignore[reportUnknownVariableType] + ] + + def unload(self) -> None: + logger.warning( + "Device management is not implemented for tensorflow>=2 models" + + f" using `{self.__class__.__name__}`, cannot unload model" + ) + + +def create_tf_model_adapter( + model_description: AnyModelDescr, devices: Optional[Sequence[str]] +): + tf_version = v0_5.Version(tf.__version__) + weights = model_description.weights.tensorflow_saved_model_bundle + if weights is None: + raise ValueError("No `tensorflow_saved_model_bundle` weights found") + + model_tf_version = weights.tensorflow_version + if model_tf_version is None: + logger.warning( + "The model does not specify the tensorflow version. " + + f"Cannot check if it is compatible with the installed tensorflow {tf_version}." + ) + elif model_tf_version > tf_version: + logger.warning( + f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
+ ) + elif (model_tf_version.major, model_tf_version.minor) != ( + tf_version.major, + tf_version.minor, + ): + logger.warning( + "The tensorflow version specified by the model does not match the installed one: " + + f"{model_tf_version} != {tf_version}." + ) + + if tf_version.major <= 1: + return TensorflowModelAdapter( + model_description=model_description, devices=devices + ) + else: + return KerasModelAdapter(model_description=model_description, devices=devices) diff --git a/bioimageio/core/backends/torchscript_backend.py b/bioimageio/core/backends/torchscript_backend.py new file mode 100644 index 00000000..ce3ba131 --- /dev/null +++ b/bioimageio/core/backends/torchscript_backend.py @@ -0,0 +1,74 @@ +# pyright: reportUnknownVariableType=false +import gc +import warnings +from typing import Any, List, Optional, Sequence, Union + +import torch +from numpy.typing import NDArray + +from bioimageio.spec._internal.type_guards import is_list, is_tuple +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..model_adapters import ModelAdapter + + +class TorchscriptModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + super().__init__(model_description=model_description) + if model_description.weights.torchscript is None: + raise ValueError( + f"No torchscript weights found for model {model_description.name}" + ) + + weight_path = download(model_description.weights.torchscript.source).path + if devices is None: + self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] + else: + self.devices = [torch.device(d) for d in devices] + + if len(self.devices) > 1: + warnings.warn( + "Multiple devices for single torchscript model not yet implemented" + ) + + with weight_path.open("rb") as f: + self._model = torch.jit.load(f) + + self._model.to(self.devices[0]) + self._model = self._model.eval() + + def _forward_impl( + self, input_arrays: Sequence[Optional[NDArray[Any]]] + ) -> List[Optional[NDArray[Any]]]: + + with torch.no_grad(): + torch_tensor = [ + None if a is None else torch.from_numpy(a).to(self.devices[0]) + for a in input_arrays + ] + output: Any = self._model.forward(*torch_tensor) + if is_list(output) or is_tuple(output): + output_seq: Sequence[Any] = output + else: + output_seq = [output] + + return [ + ( + None + if r is None + else r.cpu().numpy() if isinstance(r, torch.Tensor) else r + ) + for r in output_seq + ] + + def unload(self) -> None: + self.devices = None + del self._model + _ = gc.collect() # deallocate memory + torch.cuda.empty_cache() # release reserved memory diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index f7740092..4e40c1cf 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -258,6 +258,7 @@ def split_shape_into_blocks( set(block_shape), ) if any(shape[a] < block_shape[a] for a in block_shape): + # TODO: allow larger blockshape raise ValueError(f"shape {shape} is smaller than block shape {block_shape}") assert all(a in shape for a in halo), (tuple(shape), set(halo)) diff --git a/bioimageio/core/cli.py b/bioimageio/core/cli.py index fad44ab3..8e62239d 100644 --- a/bioimageio/core/cli.py +++ b/bioimageio/core/cli.py @@ -8,9 +8,11 @@ import shutil import subprocess import sys +from abc import ABC from argparse import RawTextHelpFormatter from difflib import SequenceMatcher from functools import cached_property +from io import StringIO from pathlib import
Path from pprint import pformat, pprint from typing import ( @@ -18,6 +20,7 @@ Dict, Iterable, List, + Literal, Mapping, Optional, Sequence, @@ -27,8 +30,9 @@ Union, ) +import rich.markdown from loguru import logger -from pydantic import BaseModel, Field, model_validator +from pydantic import AliasChoices, BaseModel, Field, model_validator from pydantic_settings import ( BaseSettings, CliPositionalArg, @@ -39,26 +43,30 @@ SettingsConfigDict, YamlConfigSettingsSource, ) -from ruyaml import YAML from tqdm import tqdm from typing_extensions import assert_never -from bioimageio.spec import AnyModelDescr, InvalidDescr, load_description +from bioimageio.spec import ( + AnyModelDescr, + InvalidDescr, + ResourceDescr, + load_description, + save_bioimageio_yaml_only, + settings, + update_format, + update_hashes, +) +from bioimageio.spec._internal.io import is_yaml_value from bioimageio.spec._internal.io_basics import ZipPath +from bioimageio.spec._internal.io_utils import open_bioimageio_yaml from bioimageio.spec._internal.types import NotEmpty from bioimageio.spec.dataset import DatasetDescr from bioimageio.spec.model import ModelDescr, v0_4, v0_5 from bioimageio.spec.notebook import NotebookDescr -from bioimageio.spec.utils import download, ensure_description_is_model - -from .commands import ( - WeightFormatArgAll, - WeightFormatArgAny, - package, - test, - validate_format, -) -from .common import MemberId, SampleId +from bioimageio.spec.utils import download, ensure_description_is_model, write_yaml + +from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test +from .common import MemberId, SampleId, SupportedWeightsFormat from .digest_spec import get_member_ids, load_sample_for_model from .io import load_dataset_stat, save_dataset_stat, save_sample from .prediction import create_prediction_pipeline @@ -71,9 +79,15 @@ ) from .sample import Sample from .stat_measures import Stat -from .utils import VERSION - -yaml = YAML(typ="safe") +from .utils import VERSION, compare +from .weight_converters._add_weights import add_weights + +WEIGHT_FORMAT_ALIASES = AliasChoices( + "weight-format", + "weights-format", + "weight_format", + "weights_format", +) class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True): @@ -84,9 +98,31 @@ class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True pass +class WithSummaryLogging(ArgMixin): + summary: Union[ + Literal["display"], Path, Sequence[Union[Literal["display"], Path]] + ] = Field( + "display", + examples=[ + "display", + Path("summary.md"), + Path("bioimageio_summaries/"), + ["display", Path("summary.md")], + ], + ) + """Display the validation summary or save it as JSON, Markdown or HTML. + The format is chosen based on the suffix: `.json`, `.md`, `.html`. + If a folder is given (path w/o suffix) the summary is saved in all formats. + Choose/add `"display"` to render the validation summary to the terminal. + """ + + def log(self, descr: Union[ResourceDescr, InvalidDescr]): + _ = descr.validation_summary.log(self.summary) + + class WithSource(ArgMixin): source: CliPositionalArg[str] - """Url/path to a bioimageio.yaml/rdf.yaml file + """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file or a bioimage.io resource identifier, e.g. 
'affable-shark'""" @cached_property @@ -100,29 +136,49 @@ def descr_id(self) -> str: """ if isinstance(self.descr, InvalidDescr): return str(getattr(self.descr, "id", getattr(self.descr, "name"))) - else: - return str( - ( - (bio_config := self.descr.config.get("bioimageio", {})) - and isinstance(bio_config, dict) - and bio_config.get("nickname") - ) - or self.descr.id - or self.descr.name - ) + + nickname = None + if ( + isinstance(self.descr.config, v0_5.Config) + and (bio_config := self.descr.config.bioimageio) + and bio_config.model_extra is not None + ): + nickname = bio_config.model_extra.get("nickname") + + return str(nickname or self.descr.id or self.descr.name) -class ValidateFormatCmd(CmdBase, WithSource): - """validate the meta data format of a bioimageio resource.""" +class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging): + """Validate the meta data format of a bioimageio resource.""" + + perform_io_checks: bool = Field( + settings.perform_io_checks, alias="perform-io-checks" + ) + """Whether or not to perform validations that require downloading remote files. + Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable. + """ + + @cached_property + def descr(self): + return load_description(self.source, perform_io_checks=self.perform_io_checks) def run(self): - sys.exit(validate_format(self.descr)) + self.log(self.descr) + sys.exit( + 0 + if self.descr.validation_summary.status in ("valid-format", "passed") + else 1 + ) -class TestCmd(CmdBase, WithSource): - """Test a bioimageio resource (beyond meta data formatting)""" +class TestCmd(CmdBase, WithSource, WithSummaryLogging): + """Test a bioimageio resource (beyond meta data formatting).""" - weight_format: WeightFormatArgAll = "all" + weight_format: WeightFormatArgAll = Field( + "all", + alias="weight-format", + validation_alias=WEIGHT_FORMAT_ALIASES, + ) """The weight format to limit testing to. (only relevant for model resources)""" @@ -130,8 +186,24 @@ class TestCmd(CmdBase, WithSource): devices: Optional[Union[str, Sequence[str]]] = None """Device(s) to use for testing""" - decimal: int = 4 - """Precision for numerical comparisons""" + runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field( + "currently-active", alias="runtime-env" + ) + """The python environment to run the tests in + - `"currently-active"`: use active Python interpreter + - `"as-described"`: generate a conda environment YAML file based on the model + weights description. + - A path to a conda environment YAML. + Note: The `bioimageio.core` dependency will be added automatically if not present. + """ + + determinism: Literal["seed_only", "full"] = "seed_only" + """Modes to improve reproducibility of test outputs.""" + + stop_early: bool = Field( + False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x") + ) + """Do not run further subtests after a failed one.""" def run(self): sys.exit( @@ -139,26 +211,32 @@ def run(self): self.descr, weight_format=self.weight_format, devices=self.devices, - decimal=self.decimal, + summary=self.summary, + runtime_env=self.runtime_env, + determinism=self.determinism, ) ) -class PackageCmd(CmdBase, WithSource): - """save a resource's metadata with its associated files.""" +class PackageCmd(CmdBase, WithSource, WithSummaryLogging): - """Save a resource's metadata with its associated files.""" path: CliPositionalArg[Path] """The path to write the (zipped) package to.
If it does not have a `.zip` suffix this command will save the package as an unzipped folder instead.""" - weight_format: WeightFormatArgAll = "all" + weight_format: WeightFormatArgAll = Field( + "all", + alias="weight-format", + validation_alias=WEIGHT_FORMAT_ALIASES, + ) """The weight format to include in the package (for model descriptions only).""" def run(self): if isinstance(self.descr, InvalidDescr): - self.descr.validation_summary.display() - raise ValueError("resource description is invalid") + self.log(self.descr) + raise ValueError(f"Invalid {self.descr.type} description.") sys.exit( package( @@ -182,7 +260,7 @@ def _get_stat( req_dataset_meas, _ = get_required_dataset_measures(model_descr) if stats_path.exists(): - logger.info(f"loading precomputed dataset measures from {stats_path}") + logger.info("loading precomputed dataset measures from {}", stats_path) stat = load_dataset_stat(stats_path) for m in req_dataset_meas: if m not in stat: @@ -203,6 +281,110 @@ def _get_stat( return stat +class UpdateCmdBase(CmdBase, WithSource, ABC): + output: Union[Literal["display", "stdout"], Path] = "display" + """Output updated bioimageio.yaml to the terminal or write to a file. + Notes: + - `"display"`: Render to the terminal with syntax highlighting. + - `"stdout"`: Write to sys.stdout without syntax highlighting. + (More convenient for copying the updated bioimageio.yaml from the terminal.) + """ + + diff: Union[bool, Path] = Field(True, alias="diff") + """Output a diff of original and updated bioimageio.yaml. + If a given path has an `.html` extension, a standalone HTML file is written, + otherwise the diff is saved in unified diff format (pure text). + """ + + exclude_unset: bool = Field(True, alias="exclude-unset") + """Exclude fields that have not explicitly been set.""" + + exclude_defaults: bool = Field(False, alias="exclude-defaults") + """Exclude fields that have the default value (even if set explicitly).""" + + @cached_property + def updated(self) -> Union[ResourceDescr, InvalidDescr]: + raise NotImplementedError + + def run(self): + original_yaml = open_bioimageio_yaml(self.source).unparsed_content + assert isinstance(original_yaml, str) + stream = StringIO() + + save_bioimageio_yaml_only( + self.updated, + stream, + exclude_unset=self.exclude_unset, + exclude_defaults=self.exclude_defaults, + ) + updated_yaml = stream.getvalue() + + diff = compare( + original_yaml.split("\n"), + updated_yaml.split("\n"), + diff_format=( + "html" + if isinstance(self.diff, Path) and self.diff.suffix == ".html" + else "unified" + ), + ) + + if isinstance(self.diff, Path): + _ = self.diff.write_text(diff, encoding="utf-8") + elif self.diff: + console = rich.console.Console() + diff_md = f"## Diff\n\n````````diff\n{diff}\n````````" + console.print(rich.markdown.Markdown(diff_md)) + + if isinstance(self.output, Path): + _ = self.output.write_text(updated_yaml, encoding="utf-8") + logger.info(f"wrote updated description to {self.output}") + elif self.output == "display": + updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```" + rich.console.Console().print(rich.markdown.Markdown(updated_md)) + elif self.output == "stdout": + print(updated_yaml) + else: + assert_never(self.output) + + if isinstance(self.updated, InvalidDescr): + logger.warning("Update resulted in invalid description") + _ = self.updated.validation_summary.display() + + +class UpdateFormatCmd(UpdateCmdBase): + """Update the metadata format to the latest format version.""" + + exclude_defaults: bool = Field(True,
alias="exclude-defaults") + """Exclude fields that have the default value (even if set explicitly). + + Note: + The update process sets most unset fields explicitly with their default value. + """ + + perform_io_checks: bool = Field( + settings.perform_io_checks, alias="perform-io-checks" + ) + """Wether or not to attempt validation that may require file download. + If `True` file hash values are added if not present.""" + + @cached_property + def updated(self): + return update_format( + self.source, + exclude_defaults=self.exclude_defaults, + perform_io_checks=self.perform_io_checks, + ) + + +class UpdateHashesCmd(UpdateCmdBase): + """Create a bioimageio.yaml description with updated file hashes.""" + + @cached_property + def updated(self): + return update_hashes(self.source) + + class PredictCmd(CmdBase, WithSource): """Run inference on your data with a bioimage.io model.""" @@ -222,7 +404,7 @@ class PredictCmd(CmdBase, WithSource): Example inputs to process sample 'a' and 'b' for a model expecting a 'raw' and a 'mask' input tensor: - --inputs="[[\"a_raw.tif\",\"a_mask.tif\"],[\"b_raw.tif\",\"b_mask.tif\"]]" + --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]" (Note that JSON double quotes need to be escaped.) Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file @@ -270,7 +452,11 @@ class PredictCmd(CmdBase, WithSource): """preview which files would be processed and what outputs would be generated.""" - weight_format: WeightFormatArgAny = "any" + weight_format: WeightFormatArgAny = Field( + "any", + alias="weight-format", + validation_alias=WEIGHT_FORMAT_ALIASES, + ) """The weight format to use.""" example: bool = False @@ -318,13 +504,15 @@ def _example(self): bioimageio_cli_path = example_path / YAML_FILE stats_file = "dataset_statistics.json" stats = (example_path / stats_file).as_posix() - yaml.dump( - dict( - inputs=inputs, - outputs=output_pattern, - stats=stats_file, - blockwise=self.blockwise, - ), + cli_example_args = dict( + inputs=inputs, + outputs=output_pattern, + stats=stats_file, + blockwise=self.blockwise, + ) + assert is_yaml_value(cli_example_args) + write_yaml( + cli_example_args, bioimageio_cli_path, ) @@ -545,16 +733,49 @@ def input_dataset(stat: Stat): save_sample(sp_out, sample_out) +class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging): + output: CliPositionalArg[Path] + """The path to write the updated model package to.""" + + source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format") + """Exclusively use these weights to convert to other formats.""" + + target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format") + """Exclusively add this weight format.""" + + verbose: bool = False + """Log more (error) output.""" + + def run(self): + model_descr = ensure_description_is_model(self.descr) + if isinstance(model_descr, v0_4.ModelDescr): + raise TypeError( + f"model format {model_descr.format_version} not supported." + + " Please update the model first." 
+ ) + updated_model_descr = add_weights( + model_descr, + output_path=self.output, + source_format=self.source_format, + target_format=self.target_format, + verbose=self.verbose, + ) + if updated_model_descr is None: + return + + self.log(updated_model_descr) + + JSON_FILE = "bioimageio-cli.json" YAML_FILE = "bioimageio-cli.yaml" class Bioimageio( BaseSettings, + cli_implicit_flags=True, cli_parse_args=True, cli_prog_name="bioimageio", cli_use_class_docs_for_groups=True, - cli_implicit_flags=True, use_attribute_docstrings=True, ): """bioimageio - CLI for bioimage.io resources 🦒""" @@ -576,6 +797,16 @@ class Bioimageio( predict: CliSubCommand[PredictCmd] "Predict with a model resource" + update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format") + """Update the metadata format.""" + + update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes") + """Create a bioimageio.yaml description with updated file hashes.""" + + add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights") + """Add weights converted from the available formats to the model + description to improve deployability.""" + @classmethod def settings_customise_sources( cls, @@ -613,7 +844,15 @@ def run(self): "executing CLI command:\n{}", pformat({k: v for k, v in self.model_dump().items() if v is not None}), ) - cmd = self.validate_format or self.test or self.package or self.predict + cmd = ( + self.add_weights + or self.package + or self.predict + or self.test + or self.update_format + or self.update_hashes + or self.validate_format + ) assert cmd is not None cmd.run() diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py index c71d495f..7184014c 100644 --- a/bioimageio/core/commands.py +++ b/bioimageio/core/commands.py @@ -1,4 +1,4 @@ -"""These functions implement the logic of the bioimageio command line interface +"""These functions are used in the bioimageio command line interface defined in `bioimageio.core.cli`.""" from pathlib import Path @@ -6,18 +6,18 @@ from typing_extensions import Literal +from bioimageio.core.common import SupportedWeightsFormat from bioimageio.spec import ( InvalidDescr, ResourceDescr, save_bioimageio_package, save_bioimageio_package_as_folder, ) -from bioimageio.spec.model.v0_5 import WeightsFormat from ._resource_tests import test_description -WeightFormatArgAll = Literal[WeightsFormat, "all"] -WeightFormatArgAny = Literal[WeightsFormat, "any"] +WeightFormatArgAll = Literal[SupportedWeightsFormat, "all"] +WeightFormatArgAny = Literal[SupportedWeightsFormat, "any"] def test( @@ -25,45 +25,53 @@ def test( *, weight_format: WeightFormatArgAll = "all", devices: Optional[Union[str, Sequence[str]]] = None, - decimal: int = 4, + summary: Union[ + Literal["display"], Path, Sequence[Union[Literal["display"], Path]] + ] = "display", + runtime_env: Union[ + Literal["currently-active", "as-described"], Path + ] = "currently-active", + determinism: Literal["seed_only", "full"] = "seed_only", ) -> int: - """test a bioimageio resource + """Test a bioimageio resource.
- Args: - source: Path or URL to the bioimageio resource description file - (bioimageio.yaml or rdf.yaml) or to a zipped resource - weight_format: (model only) The weight format to use - devices: Device(s) to use for testing - decimal: Precision for numerical comparisons + Arguments as described in `bioimageio.core.cli.TestCmd` """ if isinstance(descr, InvalidDescr): - descr.validation_summary.display() - return 1 + test_summary = descr.validation_summary + else: + test_summary = test_description( + descr, + weight_format=None if weight_format == "all" else weight_format, + devices=[devices] if isinstance(devices, str) else devices, + runtime_env=runtime_env, + determinism=determinism, + ) - summary = test_description( - descr, - weight_format=None if weight_format == "all" else weight_format, - devices=[devices] if isinstance(devices, str) else devices, - decimal=decimal, - ) - summary.display() - return 0 if summary.status == "passed" else 1 + _ = test_summary.log(summary) + return 0 if test_summary.status == "passed" else 1 def validate_format( descr: Union[ResourceDescr, InvalidDescr], + summary: Union[Path, Sequence[Path]] = (), ): - """validate the meta data format of a bioimageio resource + """DEPRECATED; Access the existing `validation_summary` attribute instead. + Validate the meta data format of a bioimageio resource. Args: descr: a bioimageio resource description """ - descr.validation_summary.display() - return 0 if descr.validation_summary.status == "passed" else 1 + _ = descr.validation_summary.save(summary) + return 0 if descr.validation_summary.status in ("valid-format", "passed") else 1 +# TODO: absorb into `save_bioimageio_package` def package( - descr: ResourceDescr, path: Path, *, weight_format: WeightFormatArgAll = "all" + descr: ResourceDescr, + path: Path, + *, + weight_format: WeightFormatArgAll = "all", ): """Save a resource's metadata with its associated files. @@ -76,8 +84,12 @@ def package( weight-format: include only this single weight-format (if not 'all'). """ if isinstance(descr, InvalidDescr): - descr.validation_summary.display() - raise ValueError("resource description is invalid") + logged = descr.validation_summary.save() + msg = f"Invalid {descr.type} description." + if logged: + msg += f" Details saved to {logged}." + + raise ValueError(msg) if weight_format == "all": weights_priority_order = None diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 78a85886..9f939061 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -15,6 +15,17 @@ from bioimageio.spec.model import v0_5 +from .axis import AxisId + +SupportedWeightsFormat = Literal[ + "keras_hdf5", + "onnx", + "pytorch_state_dict", + "tensorflow_saved_model_bundle", + "torchscript", +] + + DTypeStr = Literal[ "bool", "float32", @@ -87,7 +98,19 @@ class SliceInfo(NamedTuple): SampleId = Hashable +"""ID of a sample, see `bioimageio.core.sample.Sample`""" MemberId = v0_5.TensorId +"""ID of a `Sample` member, see `bioimageio.core.sample.Sample`""" + +BlocksizeParameter = Union[ + v0_5.ParameterizedSize_N, + Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], +] +""" +Parameter to determine a concrete size for parametrized axis sizes defined by +`bioimageio.spec.model.v0_5.ParameterizedSize`.
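+ +For example (an illustrative sketch): a plain `2` sets n=2 for every parameterized axis, +while `{(MemberId("raw"), AxisId("x")): 2}` targets a single axis; the member and axis +ids used here are hypothetical.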
+""" + T = TypeVar("T") PerMember = Mapping[MemberId, T] diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index edb5a45d..fb0462f5 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -1,6 +1,9 @@ from __future__ import annotations +import collections.abc +import hashlib import importlib.util +import sys from itertools import chain from pathlib import Path from typing import ( @@ -23,9 +26,8 @@ from numpy.typing import NDArray from typing_extensions import Unpack, assert_never -from bioimageio.spec._internal.io import resolve_and_extract -from bioimageio.spec._internal.io_utils import HashKwargs -from bioimageio.spec.common import FileSource +from bioimageio.spec._internal.io import HashKwargs +from bioimageio.spec.common import FileDescr, FileSource, ZipPath from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile from bioimageio.spec.model.v0_5 import ( @@ -33,9 +35,10 @@ ArchitectureFromLibraryDescr, ParameterizedSize_N, ) -from bioimageio.spec.utils import load_array +from bioimageio.spec.utils import download, load_array -from .axis import AxisId, AxisInfo, AxisLike, PerAxis +from ._settings import settings +from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis from .block_meta import split_multiple_shapes_into_blocks from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks from .io import load_tensor @@ -48,9 +51,16 @@ from .stat_measures import Stat from .tensor import Tensor +TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path] + def import_callable( - node: Union[CallableFromDepencency, ArchitectureFromLibraryDescr], + node: Union[ + ArchitectureFromFileDescr, + ArchitectureFromLibraryDescr, + CallableFromDepencency, + CallableFromFile, + ], /, **kwargs: Unpack[HashKwargs], ) -> Callable[..., Any]: @@ -65,7 +75,6 @@ def import_callable( c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) elif isinstance(node, ArchitectureFromFileDescr): c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) - else: assert_never(node) @@ -78,17 +87,70 @@ def import_callable( def _import_from_file_impl( source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] ): - local_file = resolve_and_extract(source, **kwargs) - module_name = local_file.path.stem - importlib_spec = importlib.util.spec_from_file_location( - module_name, local_file.path - ) - if importlib_spec is None: - raise ImportError(f"Failed to import {module_name} from {source}.") + src_descr = FileDescr(source=source, **kwargs) + # ensure sha is valid even if perform_io_checks=False + src_descr.validate_sha256() + assert src_descr.sha256 is not None + + local_source = src_descr.download() + + source_bytes = local_source.path.read_bytes() + assert isinstance(source_bytes, bytes) + source_sha = hashlib.sha256(source_bytes).hexdigest() + + # make sure we have unique module name + module_name = f"{local_source.path.stem}_{source_sha}" + + # make sure we have a valid module name + if not module_name.isidentifier(): + module_name = f"custom_module_{source_sha}" + assert module_name.isidentifier(), module_name + + module = sys.modules.get(module_name) + if module is None: + try: + if isinstance(local_source.path, Path): + module_path = local_source.path + elif isinstance(local_source.path, ZipPath): + # save extract source to cache + # loading from a file from disk ensure we get readable tracebacks + # if 
any errors occur + module_path = ( + settings.cache_path / f"{source_sha}-{local_source.path.name}" + ) + _ = module_path.write_bytes(source_bytes) + else: + assert_never(local_source.path) + + importlib_spec = importlib.util.spec_from_file_location( + module_name, module_path + ) - dep = importlib.util.module_from_spec(importlib_spec) - importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? - return getattr(dep, callable_name) + if importlib_spec is None: + raise ImportError(f"Failed to import {source}") + + module = importlib.util.module_from_spec(importlib_spec) + assert importlib_spec.loader is not None + importlib_spec.loader.exec_module(module) + + except Exception as e: + raise ImportError(f"Failed to import {source}") from e + else: + sys.modules[module_name] = module # cache this module + + try: + callable_attr = getattr(module, callable_name) + except AttributeError as e: + raise AttributeError( + f"Imported custom module from {source} has no `{callable_name}` attribute." + ) from e + except Exception as e: + raise AttributeError( + f"Failed to access `{callable_name}` attribute from custom module imported from {source}." + ) from e + + else: + return callable_attr def get_axes_infos( @@ -100,14 +162,15 @@ def get_axes_infos( ], ) -> List[AxisInfo]: """get a unified, simplified axis representation from spec axes""" - return [ - ( - AxisInfo.create("i") - if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x") - else AxisInfo.create(a) - ) - for a in io_descr.axes - ] + ret: List[AxisInfo] = [] + for a in io_descr.axes: + if isinstance(a, v0_5.AxisBase): + ret.append(AxisInfo.create(Axis(id=a.id, type=a.type))) + else: + assert a in ("b", "i", "t", "c", "z", "y", "x") + ret.append(AxisInfo.create(a)) + + return ret def get_member_id( @@ -308,7 +371,7 @@ def get_io_sample_block_metas( def get_tensor( - src: Union[Tensor, xr.DataArray, NDArray[Any], Path], + src: Union[ZipPath, TensorSource], ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr], ): """helper to cast/load various tensor sources""" @@ -322,7 +385,10 @@ def get_tensor( if isinstance(src, np.ndarray): return Tensor.from_numpy(src, dims=get_axes_infos(ipt)) - if isinstance(src, Path): + if isinstance(src, FileDescr): + src = download(src).path + + if isinstance(src, (ZipPath, Path, str)): return load_tensor(src, axes=get_axes_infos(ipt)) assert_never(src) @@ -333,10 +399,7 @@ def create_sample_for_model( *, stat: Optional[Stat] = None, sample_id: SampleId = None, - inputs: Optional[ - PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]] - ] = None, # TODO: make non-optional - **kwargs: NDArray[Any], # TODO: deprecate in favor of `inputs` + inputs: Union[PerMember[TensorSource], TensorSource], ) -> Sample: """Create a sample from a single set of input(s) for a specific bioimage.io model @@ -345,9 +408,17 @@ stat: dictionary with sample and dataset statistics (may be updated in-place!) inputs: the input(s) constituting a single sample.
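+ + Example (illustrative, assuming a model whose sole input member is "raw"): + `create_sample_for_model(model, inputs={"raw": np.zeros((1, 1, 64, 64))})` + or, since single-input models accept a bare tensor source, + `create_sample_for_model(model, inputs=np.zeros((1, 1, 64, 64)))`.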
""" - inputs = {MemberId(k): v for k, v in {**kwargs, **(inputs or {})}.items()} model_inputs = {get_member_id(d): d for d in model.inputs} + if isinstance(inputs, collections.abc.Mapping): + inputs = {MemberId(k): v for k, v in inputs.items()} + elif len(model_inputs) == 1: + inputs = {list(model_inputs)[0]: inputs} + else: + raise TypeError( + f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}" + ) + if unknown := {k for k in inputs if k not in model_inputs}: raise ValueError(f"Got unexpected inputs: {unknown}") diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index ee60a67a..dc5b70db 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,16 +1,35 @@ import collections.abc import warnings +import zipfile +from io import TextIOWrapper from pathlib import Path, PurePosixPath -from typing import Any, Mapping, Optional, Sequence, Tuple, Union +from shutil import copyfileobj +from typing import ( + Any, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) -import h5py +import h5py # pyright: ignore[reportMissingTypeStubs] import numpy as np -from imageio.v3 import imread, imwrite +from imageio.v3 import imread, imwrite # type: ignore from loguru import logger from numpy.typing import NDArray from pydantic import BaseModel, ConfigDict, TypeAdapter - -from bioimageio.spec.utils import load_array, save_array +from typing_extensions import assert_never + +from bioimageio.spec._internal.io import interprete_file_source +from bioimageio.spec.common import ( + HttpUrl, + PermissiveFileSource, + RelativeFilePath, + ZipPath, +) +from bioimageio.spec.utils import download, load_array, save_array from .axis import AxisLike from .common import PerMember @@ -21,29 +40,54 @@ DEFAULT_H5_DATASET_PATH = "data" -def load_image(path: Path, is_volume: Optional[bool] = None) -> NDArray[Any]: +SUFFIXES_WITH_DATAPATH = (".h5", ".hdf", ".hdf5") + + +def load_image( + source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None +) -> NDArray[Any]: """load a single image as numpy array Args: - path: image path + source: image source is_volume: deprecated """ if is_volume is not None: warnings.warn("**is_volume** is deprecated and will be removed soon.") - file_path, subpath = _split_dataset_path(Path(path)) + if isinstance(source, ZipPath): + parsed_source = source + else: + parsed_source = interprete_file_source(source) - if file_path.suffix == ".npy": + if isinstance(parsed_source, RelativeFilePath): + src = parsed_source.absolute() + else: + src = parsed_source + + # FIXME: why is pyright complaining about giving the union to _split_dataset_path? 
+ if isinstance(src, Path): + file_source, subpath = _split_dataset_path(src) + elif isinstance(src, HttpUrl): + file_source, subpath = _split_dataset_path(src) + elif isinstance(src, ZipPath): + file_source, subpath = _split_dataset_path(src) + else: + assert_never(src) + + path = download(file_source).path + + if path.suffix == ".npy": if subpath is not None: raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}") return load_array(path) - elif file_path.suffix in (".h5", ".hdf", ".hdf5"): + elif path.suffix in SUFFIXES_WITH_DATAPATH: if subpath is None: dataset_path = DEFAULT_H5_DATASET_PATH else: dataset_path = str(subpath) - with h5py.File(file_path, "r") as f: + with h5py.File(path, "r") as f: h5_dataset = f.get( # pyright: ignore[reportUnknownVariableType] dataset_path ) @@ -60,41 +104,81 @@ def load_image(path: Path, is_volume: Optional[bool] = None) -> NDArray[Any]: image # pyright: ignore[reportUnknownArgumentType] ) return image # pyright: ignore[reportUnknownVariableType] + elif isinstance(path, ZipPath): + return imread( + path.read_bytes(), extension=path.suffix + ) # pyright: ignore[reportUnknownVariableType] else: return imread(path) # pyright: ignore[reportUnknownVariableType] -def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor: +def load_tensor( + path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None +) -> Tensor: # TODO: load axis meta data array = load_image(path) return Tensor.from_numpy(array, dims=axes) -def _split_dataset_path(path: Path) -> Tuple[Path, Optional[PurePosixPath]]: +_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath) + + +def _split_dataset_path( + source: _SourceT, +) -> Tuple[_SourceT, Optional[PurePosixPath]]: """Split off subpath (e.g. internal h5 dataset path) from a file path following a file extension. 
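+ + Supports local paths, HTTP URLs and paths inside zip archives; the file source + is returned with its original type.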
Examples: >>> _split_dataset_path(Path("my_file.h5/dataset")) - (PosixPath('my_file.h5'), PurePosixPath('dataset')) + (...Path('my_file.h5'), PurePosixPath('dataset')) - If no suffix is detected the path is returned with >>> _split_dataset_path(Path("my_plain_file")) - (PosixPath('my_plain_file'), None) + (...Path('my_plain_file'), None) """ - if path.suffix: + if isinstance(source, RelativeFilePath): + src = source.absolute() + else: + src = source + + del source + + def separate_pure_path(path: PurePosixPath): + for p in path.parents: + if p.suffix in SUFFIXES_WITH_DATAPATH: + return p, PurePosixPath(path.relative_to(p)) + return path, None - for p in path.parents: - if p.suffix: - return p, PurePosixPath(path.relative_to(p)) + if isinstance(src, HttpUrl): + file_path, data_path = separate_pure_path(PurePosixPath(src.path or "")) - return path, None + if data_path is None: + return src, None + return ( + HttpUrl(str(file_path).replace(f"/{data_path}", "")), + data_path, + ) + + if isinstance(src, ZipPath): + file_path, data_path = separate_pure_path(PurePosixPath(str(src))) + + if data_path is None: + return src, None + + return ( + ZipPath(str(file_path).replace(f"/{data_path}", "")), + data_path, + ) -def save_tensor(path: Path, tensor: Tensor) -> None: + file_path, data_path = separate_pure_path(PurePosixPath(src)) + return Path(file_path), data_path + + +def save_tensor(path: Union[Path, str], tensor: Tensor) -> None: # TODO: save axis meta data data: NDArray[Any] = tensor.data.to_numpy() @@ -134,23 +218,33 @@ def save_tensor(path: Path, tensor: Tensor) -> None: imwrite(path, data) -def save_sample(path: Union[Path, str, PerMember[Path]], sample: Sample) -> None: - """save a sample to path - - If `path` is a pathlib.Path or a string it must contain `{member_id}` and may contain `{sample_id}`, - which are resolved with the `sample` object. - """ +def save_sample( + path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample +) -> None: + """Save a **sample** to a **path** pattern + or all sample members in the **path** mapping. - if not isinstance(path, collections.abc.Mapping) and "{member_id}" not in str(path): - raise ValueError(f"missing `{{member_id}}` in path {path}") + If **path** is a pathlib.Path or a string and the **sample** has multiple members, + **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`). - for m, t in sample.members.items(): - if isinstance(path, collections.abc.Mapping): - p = path[m] + (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object. + """ + if not isinstance(path, collections.abc.Mapping): + if len(sample.members) < 2 or any( + m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}") + ): + path = {m: path for m in sample.members} else: - p = Path(str(path).format(sample_id=sample.id, member_id=m)) + raise ValueError( + f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
+ ) - save_tensor(p, t) + for m, p in path.items(): + t = sample.members[m] + p_formatted = Path( + str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m) + ) + save_tensor(p_formatted, t) class _SerializedDatasetStatsEntry( @@ -176,3 +270,27 @@ def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path): def load_dataset_stat(path: Path): seq = _stat_adapter.validate_json(path.read_bytes()) return {e.measure: e.value for e in seq} + + +def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path): + """Unzip a (downloaded) **source** to a file in **folder** if **source** is a zip archive. + Always returns the path to the unzipped source (which may be **source** itself).""" + local_weights_file = download(source).path + if isinstance(local_weights_file, ZipPath): + # source is inside a zip archive + out_path = folder / local_weights_file.filename + with local_weights_file.open("rb") as src, out_path.open("wb") as dst: + assert not isinstance(src, TextIOWrapper) + copyfileobj(src, dst) + + local_weights_file = out_path + + if zipfile.is_zipfile(local_weights_file): + # source itself is a zipfile + out_path = folder / local_weights_file.with_suffix(".unzipped").name + with zipfile.ZipFile(local_weights_file, "r") as f: + f.extractall(out_path) + + return out_path + else: + return local_weights_file diff --git a/bioimageio/core/model_adapters.py b/bioimageio/core/model_adapters.py new file mode 100644 index 00000000..db92d013 --- /dev/null +++ b/bioimageio/core/model_adapters.py @@ -0,0 +1,22 @@ +"""DEPRECATED""" + +from typing import List + +from .backends._model_adapter import ( + DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER, + ModelAdapter, + create_model_adapter, +) + +__all__ = [ + "ModelAdapter", + "create_model_adapter", + "get_weight_formats", +] + + +def get_weight_formats() -> List[str]: + """ + Return list of supported weight types + """ + return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) diff --git a/bioimageio/core/model_adapters/__init__.py b/bioimageio/core/model_adapters/__init__.py deleted file mode 100644 index 01899de9..00000000 --- a/bioimageio/core/model_adapters/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._model_adapter import ModelAdapter, create_model_adapter, get_weight_formats - -__all__ = [ - "ModelAdapter", - "create_model_adapter", - "get_weight_formats", -] diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py deleted file mode 100644 index c918603e..00000000 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ /dev/null @@ -1,177 +0,0 @@ -import warnings -from abc import ABC, abstractmethod -from typing import List, Optional, Sequence, Tuple, Union, final - -from bioimageio.spec.model import v0_4, v0_5 - -from ..tensor import Tensor - -WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] - -# Known weight formats in order of priority -# First match wins -DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = ( - "pytorch_state_dict", - "tensorflow_saved_model_bundle", - "torchscript", - "onnx", - "keras_hdf5", -) - - -class ModelAdapter(ABC): - """ - Represents model *without* any preprocessing or postprocessing. - - ``` - from bioimageio.core import load_description - - model = load_description(...) - - # option 1: - adapter = ModelAdapter.create(model) - adapter.forward(...) - adapter.unload() - - # option 2: - with ModelAdapter.create(model) as adapter: - adapter.forward(...)
- ``` - """ - - @final - @classmethod - def create( - cls, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - *, - devices: Optional[Sequence[str]] = None, - weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None, - ): - """ - Creates model adapter based on the passed spec - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError( - f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}" - ) - - weights = model_description.weights - errors: List[Tuple[WeightsFormat, Exception]] = [] - weight_format_priority_order = ( - DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER - if weight_format_priority_order is None - else weight_format_priority_order - ) - # limit weight formats to the ones present - weight_format_priority_order = [ - w for w in weight_format_priority_order if getattr(weights, w) is not None - ] - - for wf in weight_format_priority_order: - if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None: - try: - from ._pytorch_model_adapter import PytorchModelAdapter - - return PytorchModelAdapter( - outputs=model_description.outputs, - weights=weights.pytorch_state_dict, - devices=devices, - ) - except Exception as e: - errors.append((wf, e)) - elif ( - wf == "tensorflow_saved_model_bundle" - and weights.tensorflow_saved_model_bundle is not None - ): - try: - from ._tensorflow_model_adapter import TensorflowModelAdapter - - return TensorflowModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - elif wf == "onnx" and weights.onnx is not None: - try: - from ._onnx_model_adapter import ONNXModelAdapter - - return ONNXModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - elif wf == "torchscript" and weights.torchscript is not None: - try: - from ._torchscript_model_adapter import TorchscriptModelAdapter - - return TorchscriptModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - elif wf == "keras_hdf5" and weights.keras_hdf5 is not None: - # keras can either be installed as a separate package or used as part of tensorflow - # we try to first import the keras model adapter using the separate package and, - # if it is not available, try to load the one using tf - try: - from ._keras_model_adapter import ( - KerasModelAdapter, - keras, # type: ignore - ) - - if keras is None: - from ._tensorflow_model_adapter import KerasModelAdapter - - return KerasModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - - assert errors - if len(weight_format_priority_order) == 1: - assert len(errors) == 1 - raise ValueError( - f"The '{weight_format_priority_order[0]}' model adapter could not be created" - + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n" - ) - - else: - error_list = "\n - ".join( - f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors - ) - raise ValueError( - "None of the weight format specific model adapters could be created" - + f" in this environment. Errors are:\n\n{error_list}.\n\n" - ) - - @final - def load(self, *, devices: Optional[Sequence[str]] = None) -> None: - warnings.warn("Deprecated. 
ModelAdapter is loaded on initialization") - - @abstractmethod - def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: - """ - Run forward pass of model to get model predictions - """ - # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl - - @abstractmethod - def unload(self): - """ - Unload model from any devices, freeing their memory. - The moder adapter should be considered unusable afterwards. - """ - - -def get_weight_formats() -> List[str]: - """ - Return list of supported weight types - """ - return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) - - -create_model_adapter = ModelAdapter.create diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py deleted file mode 100644 index c747de22..00000000 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ /dev/null @@ -1,71 +0,0 @@ -import warnings -from typing import Any, List, Optional, Sequence, Union - -from numpy.typing import NDArray - -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.utils import download - -from ..digest_spec import get_axes_infos -from ..tensor import Tensor -from ._model_adapter import ModelAdapter - -try: - import onnxruntime as rt -except Exception as e: - rt = None - rt_error = str(e) -else: - rt_error = None - - -class ONNXModelAdapter(ModelAdapter): - def __init__( - self, - *, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - devices: Optional[Sequence[str]] = None, - ): - if rt is None: - raise ImportError(f"failed to import onnxruntime: {rt_error}") - - super().__init__() - self._internal_output_axes = [ - tuple(a.id for a in get_axes_infos(out)) - for out in model_description.outputs - ] - if model_description.weights.onnx is None: - raise ValueError("No ONNX weights specified for {model_description.name}") - - self._session = rt.InferenceSession( - str(download(model_description.weights.onnx.source).path) - ) - onnx_inputs = self._session.get_inputs() # type: ignore - self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] # type: ignore - - if devices is not None: - warnings.warn( - f"Device management is not implemented for onnx yet, ignoring the devices {devices}" - ) - - def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: - assert len(input_tensors) == len(self._input_names) - input_arrays = [None if ipt is None else ipt.data.data for ipt in input_tensors] - result: Union[Sequence[Optional[NDArray[Any]]], Optional[NDArray[Any]]] - result = self._session.run( # pyright: ignore[reportUnknownVariableType] - None, dict(zip(self._input_names, input_arrays)) - ) - if isinstance(result, (list, tuple)): - result_seq: Sequence[Optional[NDArray[Any]]] = result - else: - result_seq = [result] # type: ignore - - return [ - None if r is None else Tensor(r, dims=axes) - for r, axes in zip(result_seq, self._internal_output_axes) - ] - - def unload(self) -> None: - warnings.warn( - "Device management is not implemented for onnx yet, cannot unload model" - ) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py deleted file mode 100644 index a5178d74..00000000 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ /dev/null @@ -1,153 +0,0 @@ -import gc -import warnings -from typing import Any, List, Optional, Sequence, Tuple, Union - -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.utils import download - -from ..axis 
import AxisId -from ..digest_spec import get_axes_infos, import_callable -from ..tensor import Tensor -from ._model_adapter import ModelAdapter - -try: - import torch -except Exception as e: - torch = None - torch_error = str(e) -else: - torch_error = None - - -class PytorchModelAdapter(ModelAdapter): - def __init__( - self, - *, - outputs: Union[ - Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr] - ], - weights: Union[ - v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr - ], - devices: Optional[Sequence[str]] = None, - ): - if torch is None: - raise ImportError(f"failed to import torch: {torch_error}") - - super().__init__() - self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs] - self._network = self.get_network(weights) - self._devices = self.get_devices(devices) - self._network = self._network.to(self._devices[0]) - - self._primary_device = self._devices[0] - state: Any = torch.load( - download(weights).path, - map_location=self._primary_device, # pyright: ignore[reportUnknownArgumentType] - ) - self._network.load_state_dict(state) - - self._network = self._network.eval() - - def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: - if torch is None: - raise ImportError("torch") - with torch.no_grad(): - tensors = [ - None if ipt is None else torch.from_numpy(ipt.data.data) - for ipt in input_tensors - ] - tensors = [ - ( - None - if t is None - else t.to( - self._primary_device # pyright: ignore[reportUnknownArgumentType] - ) - ) - for t in tensors - ] - result: Union[Tuple[Any, ...], List[Any], Any] - result = self._network( # pyright: ignore[reportUnknownVariableType] - *tensors - ) - if not isinstance(result, (tuple, list)): - result = [result] - - result = [ - ( - None - if r is None - else r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r - ) - for r in result # pyright: ignore[reportUnknownVariableType] - ] - if len(result) > len(self.output_dims): - raise ValueError( - f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}" - ) - - return [ - None if r is None else Tensor(r, dims=out) - for r, out in zip(result, self.output_dims) - ] - - def unload(self) -> None: - del self._network - _ = gc.collect() # deallocate memory - assert torch is not None - torch.cuda.empty_cache() # release reserved memory - - @staticmethod - def get_network( # pyright: ignore[reportUnknownParameterType] - weight_spec: Union[ - v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr - ], - ) -> "torch.nn.Module": # pyright: ignore[reportInvalidTypeForm] - if torch is None: - raise ImportError("torch") - arch = import_callable( - weight_spec.architecture, - sha256=( - weight_spec.architecture_sha256 - if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) - else weight_spec.sha256 - ), - ) - model_kwargs = ( - weight_spec.kwargs - if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) - else weight_spec.architecture.kwargs - ) - network = arch(**model_kwargs) - if not isinstance(network, torch.nn.Module): - raise ValueError( - f"calling {weight_spec.architecture.callable} did not return a torch.nn.Module" - ) - - return network - - @staticmethod - def get_devices( # pyright: ignore[reportUnknownParameterType] - devices: Optional[Sequence[str]] = None, - ) -> List["torch.device"]: # pyright: ignore[reportInvalidTypeForm] - if torch is None: - raise ImportError("torch") - if not devices: - torch_devices = [ - ( - torch.device("cuda") - if 
torch.cuda.is_available() - else torch.device("cpu") - ) - ] - else: - torch_devices = [torch.device(d) for d in devices] - - if len(torch_devices) > 1: - warnings.warn( - f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}" - ) - torch_devices = torch_devices[:1] - - return torch_devices diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py deleted file mode 100644 index cfb264f0..00000000 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ /dev/null @@ -1,275 +0,0 @@ -import zipfile -from typing import List, Literal, Optional, Sequence, Union - -import numpy as np -from loguru import logger - -from bioimageio.spec.common import FileSource -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.utils import download - -from ..digest_spec import get_axes_infos -from ..tensor import Tensor -from ._model_adapter import ModelAdapter - -try: - import tensorflow as tf # pyright: ignore[reportMissingImports] -except Exception as e: - tf = None - tf_error = str(e) -else: - tf_error = None - - -class TensorflowModelAdapterBase(ModelAdapter): - weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"] - - def __init__( - self, - *, - devices: Optional[Sequence[str]] = None, - weights: Union[ - v0_4.KerasHdf5WeightsDescr, - v0_4.TensorflowSavedModelBundleWeightsDescr, - v0_5.KerasHdf5WeightsDescr, - v0_5.TensorflowSavedModelBundleWeightsDescr, - ], - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - ): - if tf is None: - raise ImportError(f"failed to import tensorflow: {tf_error}") - - super().__init__() - self.model_description = model_description - tf_version = v0_5.Version( - tf.__version__ # pyright: ignore[reportUnknownArgumentType] - ) - model_tf_version = weights.tensorflow_version - if model_tf_version is None: - logger.warning( - "The model does not specify the tensorflow version." - + f"Cannot check if it is compatible with intalled tensorflow {tf_version}." - ) - elif model_tf_version > tf_version: - logger.warning( - f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}." - ) - elif (model_tf_version.major, model_tf_version.minor) != ( - tf_version.major, - tf_version.minor, - ): - logger.warning( - "The tensorflow version specified by the model does not match the installed: " - + f"{model_tf_version} != {tf_version}." 
- ) - - self.use_keras_api = ( - tf_version.major > 1 - or self.weight_format == KerasModelAdapter.weight_format - ) - - # TODO tf device management - if devices is not None: - logger.warning( - f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" - ) - - weight_file = self.require_unzipped(weights.source) - self._network = self._get_network(weight_file) - self._internal_output_axes = [ - tuple(a.id for a in get_axes_infos(out)) - for out in model_description.outputs - ] - - def require_unzipped(self, weight_file: FileSource): - loacl_weights_file = download(weight_file).path - if zipfile.is_zipfile(loacl_weights_file): - out_path = loacl_weights_file.with_suffix(".unzipped") - with zipfile.ZipFile(loacl_weights_file, "r") as f: - f.extractall(out_path) - - return out_path - else: - return loacl_weights_file - - def _get_network( # pyright: ignore[reportUnknownParameterType] - self, weight_file: FileSource - ): - weight_file = self.require_unzipped(weight_file) - assert tf is not None - if self.use_keras_api: - try: - return tf.keras.layers.TFSMLayer( - weight_file, call_endpoint="serve" - ) # pyright: ignore[reportUnknownVariableType] - except Exception as e: - try: - return tf.keras.layers.TFSMLayer( - weight_file, call_endpoint="serving_default" - ) # pyright: ignore[reportUnknownVariableType] - except Exception as ee: - logger.opt(exception=ee).info( - "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'" - ) - raise e - else: - # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model - return str(weight_file) - - # TODO currently we relaod the model every time. it would be better to keep the graph and session - # alive in between of forward passes (but then the sessions need to be properly opened / closed) - def _forward_tf( # pyright: ignore[reportUnknownParameterType] - self, *input_tensors: Optional[Tensor] - ): - assert tf is not None - input_keys = [ - ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id - for ipt in self.model_description.inputs - ] - output_keys = [ - out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id - for out in self.model_description.outputs - ] - # TODO read from spec - tag = ( # pyright: ignore[reportUnknownVariableType] - tf.saved_model.tag_constants.SERVING - ) - signature_key = ( # pyright: ignore[reportUnknownVariableType] - tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - ) - - graph = tf.Graph() # pyright: ignore[reportUnknownVariableType] - with graph.as_default(): - with tf.Session( - graph=graph - ) as sess: # pyright: ignore[reportUnknownVariableType] - # load the model and the signature - graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType] - sess, [tag], self._network - ) - signature = ( # pyright: ignore[reportUnknownVariableType] - graph_def.signature_def - ) - - # get the tensors into the graph - in_names = [ # pyright: ignore[reportUnknownVariableType] - signature[signature_key].inputs[key].name for key in input_keys - ] - out_names = [ # pyright: ignore[reportUnknownVariableType] - signature[signature_key].outputs[key].name for key in output_keys - ] - in_tensors = [ # pyright: ignore[reportUnknownVariableType] - graph.get_tensor_by_name(name) - for name in in_names # pyright: ignore[reportUnknownVariableType] - ] - out_tensors = [ # pyright: ignore[reportUnknownVariableType] - graph.get_tensor_by_name(name) - for name in out_names # pyright: 
ignore[reportUnknownVariableType] - ] - - # run prediction - res = sess.run( # pyright: ignore[reportUnknownVariableType] - dict( - zip( - out_names, # pyright: ignore[reportUnknownArgumentType] - out_tensors, # pyright: ignore[reportUnknownArgumentType] - ) - ), - dict( - zip( - in_tensors, # pyright: ignore[reportUnknownArgumentType] - input_tensors, - ) - ), - ) - # from dict to list of tensors - res = [ # pyright: ignore[reportUnknownVariableType] - res[out] - for out in out_names # pyright: ignore[reportUnknownVariableType] - ] - - return res # pyright: ignore[reportUnknownVariableType] - - def _forward_keras( # pyright: ignore[reportUnknownParameterType] - self, *input_tensors: Optional[Tensor] - ): - assert self.use_keras_api - assert not isinstance(self._network, str) - assert tf is not None - tf_tensor = [ # pyright: ignore[reportUnknownVariableType] - None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors - ] - - result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType] - - assert isinstance(result, dict) - - # TODO: Use RDF's `outputs[i].id` here - result = list(result.values()) - - return [ # pyright: ignore[reportUnknownVariableType] - (None if r is None else r if isinstance(r, np.ndarray) else r.numpy()) - for r in result # pyright: ignore[reportUnknownVariableType] - ] - - def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: - data = [None if ipt is None else ipt.data for ipt in input_tensors] - if self.use_keras_api: - result = self._forward_keras( # pyright: ignore[reportUnknownVariableType] - *data - ) - else: - result = self._forward_tf( # pyright: ignore[reportUnknownVariableType] - *data - ) - - return [ - None if r is None else Tensor(r, dims=axes) - for r, axes in zip( # pyright: ignore[reportUnknownVariableType] - result, # pyright: ignore[reportUnknownArgumentType] - self._internal_output_axes, - ) - ] - - def unload(self) -> None: - logger.warning( - "Device management is not implemented for keras yet, cannot unload model" - ) - - -class TensorflowModelAdapter(TensorflowModelAdapterBase): - weight_format = "tensorflow_saved_model_bundle" - - def __init__( - self, - *, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - devices: Optional[Sequence[str]] = None, - ): - if model_description.weights.tensorflow_saved_model_bundle is None: - raise ValueError("missing tensorflow_saved_model_bundle weights") - - super().__init__( - devices=devices, - weights=model_description.weights.tensorflow_saved_model_bundle, - model_description=model_description, - ) - - -class KerasModelAdapter(TensorflowModelAdapterBase): - weight_format = "keras_hdf5" - - def __init__( - self, - *, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - devices: Optional[Sequence[str]] = None, - ): - if model_description.weights.keras_hdf5 is None: - raise ValueError("missing keras_hdf5 weights") - - super().__init__( - model_description=model_description, - devices=devices, - weights=model_description.weights.keras_hdf5, - ) diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py deleted file mode 100644 index 0e9f3aef..00000000 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ /dev/null @@ -1,96 +0,0 @@ -import gc -import warnings -from typing import Any, List, Optional, Sequence, Tuple, Union - -import numpy as np -from numpy.typing import NDArray - -from bioimageio.spec.model import v0_4, v0_5 -from 
bioimageio.spec.utils import download - -from ..digest_spec import get_axes_infos -from ..tensor import Tensor -from ._model_adapter import ModelAdapter - -try: - import torch -except Exception as e: - torch = None - torch_error = str(e) -else: - torch_error = None - - -class TorchscriptModelAdapter(ModelAdapter): - def __init__( - self, - *, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - devices: Optional[Sequence[str]] = None, - ): - if torch is None: - raise ImportError(f"failed to import torch: {torch_error}") - - super().__init__() - if model_description.weights.torchscript is None: - raise ValueError( - f"No torchscript weights found for model {model_description.name}" - ) - - weight_path = download(model_description.weights.torchscript.source).path - if devices is None: - self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] - else: - self.devices = [torch.device(d) for d in devices] - - if len(self.devices) > 1: - warnings.warn( - "Multiple devices for single torchscript model not yet implemented" - ) - - self._model = torch.jit.load(weight_path) - self._model.to(self.devices[0]) - self._model = self._model.eval() - self._internal_output_axes = [ - tuple(a.id for a in get_axes_infos(out)) - for out in model_description.outputs - ] - - def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]: - assert torch is not None - with torch.no_grad(): - torch_tensor = [ - None if b is None else torch.from_numpy(b.data.data).to(self.devices[0]) - for b in batch - ] - _result: Union[ # pyright: ignore[reportUnknownVariableType] - Tuple[Optional[NDArray[Any]], ...], - List[Optional[NDArray[Any]]], - Optional[NDArray[Any]], - ] = self._model.forward(*torch_tensor) - if isinstance(_result, (tuple, list)): - result: Sequence[Optional[NDArray[Any]]] = _result - else: - result = [_result] - - result = [ - ( - None - if r is None - else r.cpu().numpy() if not isinstance(r, np.ndarray) else r - ) - for r in result - ] - - assert len(result) == len(self._internal_output_axes) - return [ - None if r is None else Tensor(r, dims=axes) - for r, axes in zip(result, self._internal_output_axes) - ] - - def unload(self) -> None: - assert torch is not None - self._devices = None - del self._model - _ = gc.collect() # deallocate memory - torch.cuda.empty_cache() # release reserved memory diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index 27a4129c..9fe5ce12 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -1,7 +1,6 @@ import collections.abc from pathlib import Path from typing import ( - Any, Hashable, Iterable, Iterator, @@ -11,9 +10,7 @@ Union, ) -import xarray as xr from loguru import logger -from numpy.typing import NDArray from tqdm import tqdm from bioimageio.spec import load_description @@ -22,11 +19,10 @@ from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline from .axis import AxisId -from .common import MemberId, PerMember -from .digest_spec import create_sample_for_model +from .common import BlocksizeParameter, MemberId, PerMember +from .digest_spec import TensorSource, create_sample_for_model, get_member_id from .io import save_sample from .sample import Sample -from .tensor import Tensor def predict( @@ -34,14 +30,9 @@ def predict( model: Union[ PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline ], - inputs: Union[Sample, PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]], + inputs: Union[Sample, PerMember[TensorSource], TensorSource], 
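+    # assumption: `TensorSource` (imported from .digest_spec above) covers the
+    # same kinds of values as the Union[Tensor, xr.DataArray, NDArray[Any], Path]
+    # annotation it replaces here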
    sample_id: Hashable = "sample",
-    blocksize_parameter: Optional[
-        Union[
-            v0_5.ParameterizedSize_N,
-            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
-        ]
-    ] = None,
+    blocksize_parameter: Optional[BlocksizeParameter] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
@@ -50,29 +41,31 @@
    """Run prediction for a single set of input(s) with a bioimage.io model

    Args:
-        model: model to predict with.
+        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: the input sample or the named input(s) for this model as a dictionary
        sample_id: the sample id.
-        blocksize_parameter: (optional) tile the input into blocks parametrized by
-            blocksize according to any parametrized axis sizes defined in the model RDF.
-            Note: For a predetermined, fixed block shape use `input_block_shape`
-        input_block_shape: (optional) tile the input sample tensors into blocks.
-            Note: For a parameterized block shape, not dealing with the exact block shape,
-            use `blocksize_parameter`.
-        skip_preprocessing: flag to skip the model's preprocessing
-        skip_postprocessing: flag to skip the model's postprocessing
-        save_output_path: A path with `{member_id}` `{sample_id}` in it
-            to save the output to.
+            The **sample_id** is used to format **save_output_path**
+            and to distinguish sample-specific log messages.
+        blocksize_parameter: (optional) Tile the input into blocks parametrized by
+            **blocksize_parameter** according to any parametrized axis sizes defined
+            by the **model**.
+            See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
+            Note: For a predetermined, fixed block shape use **input_block_shape**.
+        input_block_shape: (optional) Tile the input sample tensors into blocks.
+            Note: Use **blocksize_parameter** for a parameterized block shape to
+            run prediction independent of the exact block shape.
+        skip_preprocessing: Flag to skip the model's preprocessing.
+        skip_postprocessing: Flag to skip the model's postprocessing.
+        save_output_path: A path to save the output to.
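+            (e.g. the illustrative pattern `"output_{sample_id}_{output_id}.tif"`)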
+            Must contain:
+            - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
+            May contain:
+            - `{sample_id}` to avoid overwriting recurring calls
    """
-    if save_output_path is not None:
-        if "{member_id}" not in str(save_output_path):
-            raise ValueError(
-                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
-            )
-
    if isinstance(model, PredictionPipeline):
        pp = model
+        model = pp.model_description
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
@@ -82,6 +75,18 @@ def predict(
        pp = create_prediction_pipeline(model)
+    if save_output_path is not None:
+        if (
+            "{output_id}" not in str(save_output_path)
+            and "{member_id}" not in str(save_output_path)
+            and len(model.outputs) > 1
+        ):
+            raise ValueError(
+                f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
+                + "distinguish model outputs "
+                + str([get_member_id(d) for d in model.outputs])
+            )
+
    if isinstance(inputs, Sample):
        sample = inputs
    else:
@@ -127,7 +132,7 @@ def predict_many(
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
-    inputs: Iterable[PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
+    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
@@ -142,31 +147,27 @@ def predict_many(
    """Run prediction for a multiple sets of inputs with a bioimage.io model

    Args:
-        model: model to predict with.
+        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: An iterable of the named input(s) for this model as a dictionary.
-        sample_id: the sample id.
+        sample_id: The sample id.
            note: `{i}` will be formatted as the i-th sample.
-            If `{i}` (or `{i:`) is not present and `inputs` is an iterable `{i:03}` is appended.
-        blocksize_parameter: (optional) tile the input into blocks parametrized by
-            blocksize according to any parametrized axis sizes defined in the model RDF
-        skip_preprocessing: flag to skip the model's preprocessing
-        skip_postprocessing: flag to skip the model's postprocessing
-        save_output_path: A path with `{member_id}` `{sample_id}` in it
-            to save the output to.
+            If `{i}` (or `{i:`) is not present and `inputs` is not a mapping, `{i:03}`
+            is appended.
+        blocksize_parameter: (optional) Tile the input into blocks parametrized by
+            blocksize according to any parametrized axis sizes defined in the model RDF.
+        skip_preprocessing: Flag to skip the model's preprocessing.
+        skip_postprocessing: Flag to skip the model's postprocessing.
+        save_output_path: A path to save the output to.
+            Must contain:
+            - `{sample_id}` to differentiate predicted samples
+            - `{output_id}` (or `{member_id}`) if the model has multiple outputs
    """
-    if save_output_path is not None:
-        if "{member_id}" not in str(save_output_path):
-            raise ValueError(
-                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
-            )
-
-        if not isinstance(inputs, collections.abc.Mapping) and "{sample_id}" not in str(
-            save_output_path
-        ):
-            raise ValueError(
-                f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
-            )
+    if save_output_path is not None and "{sample_id}" not in str(save_output_path):
+        raise ValueError(
+            f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
+            + " to differentiate predicted samples."
+ ) if isinstance(model, PredictionPipeline): pp = model @@ -180,7 +181,6 @@ def predict_many( pp = create_prediction_pipeline(model) if not isinstance(inputs, collections.abc.Mapping): - sample_id = str(sample_id) if "{i}" not in sample_id and "{i:" not in sample_id: sample_id += "{i:03}" diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index eecf47b1..e504bf07 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -16,7 +16,11 @@ import xarray as xr from typing_extensions import Self, assert_never +from bioimageio.core.digest_spec import get_member_id from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import ( + _convert_proc, # pyright: ignore [reportPrivateUsage] +) from ._op_base import BlockedOperator, Operator from .axis import AxisId, PerAxis @@ -51,7 +55,7 @@ def _convert_axis_ids( if mode == "per_sample": ret = [] elif mode == "per_dataset": - ret = [AxisId("b")] + ret = [v0_5.BATCH_AXIS_ID] else: assert_never(mode) @@ -299,9 +303,15 @@ def from_proc_descr( member_id: MemberId, ) -> Self: kwargs = descr.kwargs - if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs): + if isinstance(kwargs, v0_5.ScaleLinearKwargs): + axis = None + elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs): axis = kwargs.axis - elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)): + elif isinstance(kwargs, v0_4.ScaleLinearKwargs): + if kwargs.axes is not None: + raise NotImplementedError( + "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5." + ) axis = None else: assert_never(kwargs) @@ -605,7 +615,7 @@ def from_proc_descr( if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs): dims = None elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs): - dims = (descr.kwargs.axis,) + dims = (AxisId(descr.kwargs.axis),) else: assert_never(descr.kwargs) @@ -657,34 +667,53 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: ] -def get_proc_class(proc_spec: ProcDescr): - if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): - return Binarize - elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): - return Clip - elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): - return EnsureDtype - elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): - return FixedZeroMeanUnitVariance - elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): - return ScaleLinear +def get_proc( + proc_descr: ProcDescr, + tensor_descr: Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, + ], +) -> Processing: + member_id = get_member_id(tensor_descr) + + if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): + return Binarize.from_proc_descr(proc_descr, member_id) + elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)): + return Clip.from_proc_descr(proc_descr, member_id) + elif isinstance(proc_descr, v0_5.EnsureDtypeDescr): + return EnsureDtype.from_proc_descr(proc_descr, member_id) + elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr): + return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id) + elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): + return ScaleLinear.from_proc_descr(proc_descr, member_id) elif isinstance( - proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) + proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) ): - return 
ScaleMeanVariance - elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): - return ScaleRange - elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): - return Sigmoid + return ScaleMeanVariance.from_proc_descr(proc_descr, member_id) + elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): + return ScaleRange.from_proc_descr(proc_descr, member_id) + elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): + return Sigmoid.from_proc_descr(proc_descr, member_id) elif ( - isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) - and proc_spec.kwargs.mode == "fixed" + isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr) + and proc_descr.kwargs.mode == "fixed" ): - return FixedZeroMeanUnitVariance + if not isinstance( + tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr) + ): + raise TypeError( + "Expected v0_4 tensor description for v0_4 processing description" + ) + + v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes) + assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr) + return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id) elif isinstance( - proc_spec, + proc_descr, (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), ): - return ZeroMeanUnitVariance + return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id) else: - assert_never(proc_spec) + assert_never(proc_descr) diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index b9afb711..1ada58b2 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -1,4 +1,5 @@ from typing import ( + Callable, Iterable, List, Mapping, @@ -11,15 +12,15 @@ from typing_extensions import assert_never +from bioimageio.core.digest_spec import get_member_id from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 -from bioimageio.spec.model.v0_5 import TensorId -from .digest_spec import get_member_ids from .proc_ops import ( AddKnownDatasetStats, + EnsureDtype, Processing, UpdateStats, - get_proc_class, + get_proc, ) from .sample import Sample from .stat_calculators import StatsCalculator @@ -45,6 +46,11 @@ class PreAndPostprocessing(NamedTuple): post: List[Processing] +class _ProcessingCallables(NamedTuple): + pre: Callable[[Sample], None] + post: Callable[[Sample], None] + + class _SetupProcessing(NamedTuple): pre: List[Processing] post: List[Processing] @@ -52,6 +58,34 @@ class _SetupProcessing(NamedTuple): post_measures: Set[Measure] +class _ApplyProcs: + def __init__(self, procs: Sequence[Processing]): + super().__init__() + self._procs = procs + + def __call__(self, sample: Sample) -> None: + for op in self._procs: + op(sample) + + +def get_pre_and_postprocessing( + model: AnyModelDescr, + *, + dataset_for_initial_statistics: Iterable[Sample], + keep_updating_initial_dataset_stats: bool = False, + fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None, +) -> _ProcessingCallables: + """Creates callables to apply pre- and postprocessing in-place to a sample""" + + setup = setup_pre_and_postprocessing( + model=model, + dataset_for_initial_statistics=dataset_for_initial_statistics, + keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, + fixed_dataset_stats=fixed_dataset_stats, + ) + return _ProcessingCallables(_ApplyProcs(setup.pre), _ApplyProcs(setup.post)) + + def setup_pre_and_postprocessing( model: AnyModelDescr, dataset_for_initial_statistics: Iterable[Sample], @@ -60,7 +94,7 @@ def setup_pre_and_postprocessing( ) -> 
PreAndPostprocessing:
    """
    Get pre- and postprocessing operators for a `model` description.
-    userd in `bioimageio.core.create_prediction_pipeline"""
+    Used in `bioimageio.core.create_prediction_pipeline`."""
    prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)

    missing_dataset_stats = {
@@ -136,65 +170,63 @@ def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures
    )


+def _prepare_procs(
+    tensor_descrs: Union[
+        Sequence[v0_4.InputTensorDescr],
+        Sequence[v0_5.InputTensorDescr],
+        Sequence[v0_4.OutputTensorDescr],
+        Sequence[v0_5.OutputTensorDescr],
+    ],
+) -> List[Processing]:
+    procs: List[Processing] = []
+    for t_descr in tensor_descrs:
+        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
+            member_id = get_member_id(t_descr)
+            procs.append(
+                EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
+            )
+
+        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
+            for proc_d in t_descr.preprocessing:
+                procs.append(get_proc(proc_d, t_descr))
+        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
+            for proc_d in t_descr.postprocessing:
+                procs.append(get_proc(proc_d, t_descr))
+        else:
+            assert_never(t_descr)
+
+        if isinstance(
+            t_descr,
+            (v0_4.InputTensorDescr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)),
+        ):
+            if len(procs) == 1:
+                # remove initial ensure_dtype if there are no other processing steps
+                assert isinstance(procs[0], EnsureDtype)
+                procs = []
+
+            # ensure 0.4 models get float32 input
+            # which has been the implicit assumption for 0.4
+            member_id = get_member_id(t_descr)
+            procs.append(
+                EnsureDtype(input=member_id, output=member_id, dtype="float32")
+            )
+
+    return procs
+
+
def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
-    pre_measures: Set[Measure] = set()
-    post_measures: Set[Measure] = set()
-
-    input_ids = set(get_member_ids(model.inputs))
-    output_ids = set(get_member_ids(model.outputs))
-
-    def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
-        procs: List[Processing] = []
-        for t_descr in tensor_descrs:
-            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
-                proc_descrs: List[
-                    Union[
-                        v0_4.PreprocessingDescr,
-                        v0_5.PreprocessingDescr,
-                        v0_4.PostprocessingDescr,
-                        v0_5.PostprocessingDescr,
-                    ]
-                ] = list(t_descr.preprocessing)
-            elif isinstance(
-                t_descr,
-                (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
-            ):
-                proc_descrs = list(t_descr.postprocessing)
-            else:
-                assert_never(t_descr)
-
-            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
-                ensure_dtype = v0_5.EnsureDtypeDescr(
-                    kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
-                )
-                if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
-                    proc_descrs.insert(0, ensure_dtype)
-
-                proc_descrs.append(ensure_dtype)
-
-            for proc_d in proc_descrs:
-                proc_class = get_proc_class(proc_d)
-                member_id = (
-                    TensorId(str(t_descr.name))
-                    if isinstance(t_descr, v0_4.TensorDescrBase)
-                    else t_descr.id
-                )
-                req = proc_class.from_proc_descr(
-                    proc_d, member_id  # pyright: ignore[reportArgumentType]
-                )
-                for m in req.required_measures:
-                    if m.member_id in input_ids:
-                        pre_measures.add(m)
-                    elif m.member_id in output_ids:
-                        post_measures.add(m)
-                    else:
-                        raise ValueError("When to raise ")
-                procs.append(req)
-        return procs
+    if isinstance(model, v0_4.ModelDescr):
+        pre = _prepare_procs(model.inputs)
+        post = _prepare_procs(model.outputs)
+    elif isinstance(model, v0_5.ModelDescr):
+        pre = _prepare_procs(model.inputs)
+        post = _prepare_procs(model.outputs)
+    else:
+        assert_never(model)

    return _SetupProcessing(
-        pre=prepare_procs(model.inputs),
-        post=prepare_procs(model.outputs),
-        pre_measures=pre_measures,
-        post_measures=post_measures,
+        pre=pre,
+        post=post,
+        pre_measures={m for proc in pre for m in proc.required_measures},
+        post_measures={m for proc in post for m in proc.required_measures},
    )
diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py
index 0620282d..0d4c3724 100644
--- a/bioimageio/core/sample.py
+++ b/bioimageio/core/sample.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from math import ceil, floor
from typing import (
+    Any,
    Callable,
    Dict,
    Generic,
@@ -14,6 +15,7 @@ )
import numpy as np
+from numpy.typing import NDArray
from typing_extensions import Self

from .axis import AxisId, PerAxis
@@ -42,21 +44,31 @@
@dataclass
class Sample:
-    """A dataset sample"""
+    """A dataset sample.
+
+    A `Sample` has `members`, which allows combining multiple tensors into a single
+    sample.
+    For example a `Sample` from a dataset with masked images may contain a
+    `MemberId("raw")` and a `MemberId("mask")` image.
+    """

    members: Dict[MemberId, Tensor]
-    """the sample's tensors"""
+    """The sample's tensors"""

    stat: Stat
-    """sample and dataset statistics"""
+    """Sample and dataset statistics"""

    id: SampleId
-    """identifier within the sample's dataset"""
+    """Identifies the `Sample` within the dataset -- typically a number or a string."""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        return {tid: t.sizes for tid, t in self.members.items()}

+    def as_arrays(self) -> Dict[str, NDArray[Any]]:
+        """Return sample as dictionary of arrays."""
+        return {str(m): t.data.to_numpy() for m, t in self.members.items()}
+
    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py
index 41233a5b..515fe843 100644
--- a/bioimageio/core/stat_calculators.py
+++ b/bioimageio/core/stat_calculators.py
@@ -1,6 +1,6 @@
from __future__ import annotations

-import collections.abc
+import collections
import warnings
from itertools import product
from typing import (
@@ -26,6 +26,8 @@
from numpy.typing import NDArray
from typing_extensions import assert_never

+from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID
+
from .axis import AxisId, PerAxis
from .common import MemberId
from .sample import Sample
@@ -47,7 +49,7 @@
from .tensor import Tensor

try:
-    import crick
+    import crick  # pyright: ignore[reportMissingImports]
except Exception:
    crick = None
@@ -120,7 +122,7 @@ class MeanVarStdCalculator:
    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
-        self._axes = None if axes is None else tuple(axes)
+        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        self._n: int = 0
        self._mean: Optional[Tensor] = None
@@ -137,7 +139,15 @@ def compute(
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

-        var = xr.dot(c, c, dims=self._axes) / n
+        if xr.__version__.startswith("2023"):
+            var = (  # pyright: ignore[reportUnknownVariableType]
+                xr.dot(c, c, dims=self._axes) / n
+            )
+        else:
+            var = (  # pyright: ignore[reportUnknownVariableType]
+                xr.dot(c, c, dim=self._axes) / n
+            )
+        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
@@ -152,6 +162,9 @@ def compute(
        }

    def update(self, sample: Sample):
+        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
+            return
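+        # below: fold this sample's mean and squared deviations into the
+        # running dataset statistics (an online mean/variance update); without
+        # the batch axis in `self._axes` there is nothing to accumulate across
+        # samples, hence the early return above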
+ tensor = sample.members[self._member_id].astype("float64", copy=False) mean_b = tensor.mean(dim=self._axes) assert mean_b.dtype == "float64" @@ -178,12 +191,16 @@ def update(self, sample: Sample): def finalize( self, ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]: - if self._mean is None: + if ( + self._axes is not None + and BATCH_AXIS_ID not in self._axes + or self._mean is None + ): return {} else: assert self._m2 is not None var = self._m2 / self._n - sqrt = np.sqrt(var) + sqrt = var**0.5 if isinstance(sqrt, (int, float)): # var and mean are scalar tensors, let's keep it consistent sqrt = Tensor.from_xarray(xr.DataArray(sqrt)) @@ -306,7 +323,8 @@ def _initialize(self, tensor_sizes: PerAxis[int]): out_sizes[d] = s self._dims, self._shape = zip(*out_sizes.items()) - d = int(np.prod(self._shape[1:])) # type: ignore + assert self._shape is not None + d = int(np.prod(self._shape[1:])) self._digest = [TDigest() for _ in range(d)] self._indices = product(*map(range, self._shape[1:])) diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index 57148058..cb3b3da9 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -186,9 +186,20 @@ def dims(self): # TODO: rename to `axes`? return cast(Tuple[AxisId, ...], self._data.dims) @property - def tagged_shape(self): - """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" - return self.sizes + def dtype(self) -> DTypeStr: + dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] + assert dt in get_args(DTypeStr) + return dt # pyright: ignore[reportReturnType] + + @property + def ndim(self): + """Number of tensor dimensions.""" + return self._data.ndim + + @property + def shape(self): + """Tuple of tensor axes lengths""" + return self._data.shape @property def shape_tuple(self): @@ -203,26 +214,21 @@ def size(self): """ return self._data.size - def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: - """Reduce this Tensor's data by applying sum along some dimension(s).""" - return self.__class__.from_xarray(self._data.sum(dim=dim)) - - @property - def ndim(self): - """Number of tensor dimensions.""" - return self._data.ndim - - @property - def dtype(self) -> DTypeStr: - dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] - assert dt in get_args(DTypeStr) - return dt # pyright: ignore[reportReturnType] - @property def sizes(self): """Ordered, immutable mapping from axis ids to axis lengths.""" return cast(Mapping[AxisId, int], self.data.sizes) + @property + def tagged_shape(self): + """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" + return self.sizes + + def argmax(self) -> Mapping[AxisId, int]: + ret = self._data.argmax(...) 
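+        # xarray's `argmax(...)` (called with an Ellipsis) returns a dict
+        # mapping each dimension name to a 0-d index array locating the
+        # maximum; the assert below narrows the type before unwrapping to ints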
+ assert isinstance(ret, dict) + return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()} + def astype(self, dtype: DTypeStr, *, copy: bool = False): """Return tensor cast to `dtype` @@ -282,14 +288,23 @@ def crop_to( def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) - def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: - return self.__class__.from_xarray(self._data.mean(dim=dim)) + def item( + self, + key: Union[ + None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]] + ] = None, + ): + """Copy a tensor element to a standard Python scalar and return it.""" + if key is None: + ret = self._data.item() + else: + ret = self[key]._data.item() - def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: - return self.__class__.from_xarray(self._data.std(dim=dim)) + assert isinstance(ret, (bool, float, int)) + return ret - def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: - return self.__class__.from_xarray(self._data.var(dim=dim)) + def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + return self.__class__.from_xarray(self._data.mean(dim=dim)) def pad( self, @@ -405,6 +420,13 @@ def resize_to( return tensor + def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + return self.__class__.from_xarray(self._data.std(dim=dim)) + + def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + """Reduce this Tensor's data by applying sum along some dimension(s).""" + return self.__class__.from_xarray(self._data.sum(dim=dim)) + def transpose( self, axes: Sequence[AxisId], @@ -423,6 +445,9 @@ def transpose( # transpose to the correct axis order return self.__class__.from_xarray(array.transpose(*axes)) + def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + return self.__class__.from_xarray(self._data.var(dim=dim)) + @classmethod def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): ndim = array.ndim diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index 84e94d38..695f0172 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -2,6 +2,8 @@ import sys from pathlib import Path +from ._compare import compare as compare + if sys.version_info < (3, 9): def files(package_name: str): diff --git a/bioimageio/core/utils/_compare.py b/bioimageio/core/utils/_compare.py new file mode 100644 index 00000000..b8c673a9 --- /dev/null +++ b/bioimageio/core/utils/_compare.py @@ -0,0 +1,30 @@ +from difflib import HtmlDiff, unified_diff +from typing import Sequence + +from typing_extensions import Literal, assert_never + + +def compare( + a: Sequence[str], + b: Sequence[str], + name_a: str = "source", + name_b: str = "updated", + *, + diff_format: Literal["unified", "html"], +): + if diff_format == "html": + diff = HtmlDiff().make_file(a, b, name_a, name_b, charset="utf-8") + elif diff_format == "unified": + diff = "\n".join( + unified_diff( + a, + b, + name_a, + name_b, + lineterm="", + ) + ) + else: + assert_never(diff_format) + + return diff diff --git a/bioimageio/core/utils/testing.py b/bioimageio/core/utils/testing.py deleted file mode 100644 index acd65d95..00000000 --- a/bioimageio/core/utils/testing.py +++ /dev/null @@ -1,28 +0,0 @@ -# TODO: move to tests/ -from functools import wraps -from typing import Any, Protocol, Type - - -class 
test_func(Protocol): - def __call__(*args: Any, **kwargs: Any): ... - - -def skip_on(exception: Type[Exception], reason: str): - """adapted from https://stackoverflow.com/a/63522579""" - import pytest - - # Func below is the real decorator and will receive the test function as param - def decorator_func(f: test_func): - @wraps(f) - def wrapper(*args: Any, **kwargs: Any): - try: - # Try to run the test - return f(*args, **kwargs) - except exception: - # If exception of given type happens - # just swallow it and raise pytest.Skip with given reason - pytest.skip(reason) - - return wrapper - - return decorator_func diff --git a/bioimageio/core/weight_converter/__init__.py b/bioimageio/core/weight_converter/__init__.py deleted file mode 100644 index 5f1674c9..00000000 --- a/bioimageio/core/weight_converter/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""coming soon""" diff --git a/bioimageio/core/weight_converter/keras/__init__.py b/bioimageio/core/weight_converter/keras/__init__.py deleted file mode 100644 index 195b42b8..00000000 --- a/bioimageio/core/weight_converter/keras/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: update keras weight converters diff --git a/bioimageio/core/weight_converter/keras/_tensorflow.py b/bioimageio/core/weight_converter/keras/_tensorflow.py deleted file mode 100644 index c901f458..00000000 --- a/bioimageio/core/weight_converter/keras/_tensorflow.py +++ /dev/null @@ -1,151 +0,0 @@ -# type: ignore # TODO: type -import os -import shutil -from pathlib import Path -from typing import no_type_check -from zipfile import ZipFile - -try: - import tensorflow.saved_model -except Exception: - tensorflow = None - -from bioimageio.spec._internal.io_utils import download -from bioimageio.spec.model.v0_5 import ModelDescr - - -def _zip_model_bundle(model_bundle_folder: Path): - zipped_model_bundle = model_bundle_folder.with_suffix(".zip") - - with ZipFile(zipped_model_bundle, "w") as zip_obj: - for root, _, files in os.walk(model_bundle_folder): - for filename in files: - src = os.path.join(root, filename) - zip_obj.write(src, os.path.relpath(src, model_bundle_folder)) - - try: - shutil.rmtree(model_bundle_folder) - except Exception: - print("TensorFlow bundled model was not removed after compression") - - return zipped_model_bundle - - -# adapted from -# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 -def _convert_tf1( - keras_weight_path: Path, - output_path: Path, - input_name: str, - output_name: str, - zip_weights: bool, -): - try: - # try to build the tf model with the keras import from tensorflow - from bioimageio.core.weight_converter.keras._tensorflow import ( - keras, # type: ignore - ) - - except Exception: - # if the above fails try to export with the standalone keras - import keras - - @no_type_check - def build_tf_model(): - keras_model = keras.models.load_model(keras_weight_path) - assert tensorflow is not None - builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path) - signature = tensorflow.saved_model.signature_def_utils.predict_signature_def( - inputs={input_name: keras_model.input}, - outputs={output_name: keras_model.output}, - ) - - signature_def_map = { - tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature - } - - builder.add_meta_graph_and_variables( - keras.backend.get_session(), - [tensorflow.saved_model.tag_constants.SERVING], - signature_def_map=signature_def_map, - ) - builder.save() - - build_tf_model() - - if zip_weights: - output_path = 
_zip_model_bundle(output_path) - print("TensorFlow model exported to", output_path) - - return 0 - - -def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool): - try: - # try to build the tf model with the keras import from tensorflow - from bioimageio.core.weight_converter.keras._tensorflow import keras - except Exception: - # if the above fails try to export with the standalone keras - import keras - - model = keras.models.load_model(keras_weight_path) - keras.models.save_model(model, output_path) - - if zip_weights: - output_path = _zip_model_bundle(output_path) - print("TensorFlow model exported to", output_path) - - return 0 - - -def convert_weights_to_tensorflow_saved_model_bundle( - model: ModelDescr, output_path: Path -): - """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'. - - Adapted from - https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py - - Args: - model: The bioimageio model description - output_path: where to save the tensorflow weights. This path must not exist yet. - """ - assert tensorflow is not None - tf_major_ver = int(tensorflow.__version__.split(".")[0]) - - if output_path.suffix == ".zip": - output_path = output_path.with_suffix("") - zip_weights = True - else: - zip_weights = False - - if output_path.exists(): - raise ValueError(f"The ouptut directory at {output_path} must not exist.") - - if model.weights.keras_hdf5 is None: - raise ValueError("Missing Keras Hdf5 weights to convert from.") - - weight_spec = model.weights.keras_hdf5 - weight_path = download(weight_spec.source).path - - if weight_spec.tensorflow_version: - model_tf_major_ver = int(weight_spec.tensorflow_version.major) - if model_tf_major_ver != tf_major_ver: - raise RuntimeError( - f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}" - ) - - if tf_major_ver == 1: - if len(model.inputs) != 1 or len(model.outputs) != 1: - raise NotImplementedError( - "Weight conversion for models with multiple inputs or outputs is not yet implemented." 
- ) - return _convert_tf1( - weight_path, - output_path, - model.inputs[0].id, - model.outputs[0].id, - zip_weights, - ) - else: - return _convert_tf2(weight_path, output_path, zip_weights) diff --git a/bioimageio/core/weight_converter/torch/__init__.py b/bioimageio/core/weight_converter/torch/__init__.py deleted file mode 100644 index 1b1ba526..00000000 --- a/bioimageio/core/weight_converter/torch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: torch weight converters diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py deleted file mode 100644 index 3935e1d1..00000000 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ /dev/null @@ -1,108 +0,0 @@ -# type: ignore # TODO: type -import warnings -from pathlib import Path -from typing import Any, List, Sequence, cast - -import numpy as np -from numpy.testing import assert_array_almost_equal - -from bioimageio.spec import load_description -from bioimageio.spec.common import InvalidDescr -from bioimageio.spec.model import v0_4, v0_5 - -from ...digest_spec import get_member_id, get_test_inputs -from ...weight_converter.torch._utils import load_torch_model - -try: - import torch -except ImportError: - torch = None - - -def add_onnx_weights( - model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", - *, - output_path: Path, - use_tracing: bool = True, - test_decimal: int = 4, - verbose: bool = False, - opset_version: "int | None" = None, -): - """Convert model weights from format 'pytorch_state_dict' to 'onnx'. - - Args: - source_model: model without onnx weights - opset_version: onnx opset version - use_tracing: whether to use tracing or scripting to export the onnx format - test_decimal: precision for testing whether the results agree - """ - if isinstance(model_spec, (str, Path)): - loaded_spec = load_description(Path(model_spec)) - if isinstance(loaded_spec, InvalidDescr): - raise ValueError(f"Bad resource description: {loaded_spec}") - if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError( - f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr" - ) - model_spec = loaded_spec - - state_dict_weights_descr = model_spec.weights.pytorch_state_dict - if state_dict_weights_descr is None: - raise ValueError( - "The provided model does not have weights in the pytorch state dict format" - ) - - assert torch is not None - with torch.no_grad(): - - sample = get_test_inputs(model_spec) - input_data = [sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs] - input_tensors = [torch.from_numpy(ipt) for ipt in input_data] - model = load_torch_model(state_dict_weights_descr) - - expected_tensors = model(*input_tensors) - if isinstance(expected_tensors, torch.Tensor): - expected_tensors = [expected_tensors] - expected_outputs: List[np.ndarray[Any, Any]] = [ - out.numpy() for out in expected_tensors - ] - - if use_tracing: - torch.onnx.export( - model, - tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0], - str(output_path), - verbose=verbose, - opset_version=opset_version, - ) - else: - raise NotImplementedError - - try: - import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs] - except ImportError: - msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked." 
- warnings.warn(msg) - return - - # check the onnx model - sess = rt.InferenceSession(str(output_path)) - onnx_input_node_args = cast( - List[Any], sess.get_inputs() - ) # fixme: remove cast, try using rt.NodeArg instead of Any - onnx_inputs = { - input_name.name: inp - for input_name, inp in zip(onnx_input_node_args, input_data) - } - outputs = cast( - Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs) - ) # FIXME: remove cast - - try: - for exp, out in zip(expected_outputs, outputs): - assert_array_almost_equal(exp, out, decimal=test_decimal) - return 0 - except AssertionError as e: - msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}" - warnings.warn(msg) - return 1 diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py deleted file mode 100644 index 5ca16069..00000000 --- a/bioimageio/core/weight_converter/torch/_torchscript.py +++ /dev/null @@ -1,146 +0,0 @@ -# type: ignore # TODO: type -from pathlib import Path -from typing import List, Sequence, Union - -import numpy as np -from numpy.testing import assert_array_almost_equal -from typing_extensions import Any, assert_never - -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import Version - -from ._utils import load_torch_model - -try: - import torch -except ImportError: - torch = None - - -# FIXME: remove Any -def _check_predictions( - model: Any, - scripted_model: Any, - model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", - input_data: Sequence["torch.Tensor"], -): - assert torch is not None - - def _check(input_: Sequence[torch.Tensor]) -> None: - expected_tensors = model(*input_) - if isinstance(expected_tensors, torch.Tensor): - expected_tensors = [expected_tensors] - expected_outputs: List[np.ndarray[Any, Any]] = [ - out.numpy() for out in expected_tensors - ] - - output_tensors = scripted_model(*input_) - if isinstance(output_tensors, torch.Tensor): - output_tensors = [output_tensors] - outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors] - - try: - for exp, out in zip(expected_outputs, outputs): - assert_array_almost_equal(exp, out, decimal=4) - except AssertionError as e: - raise ValueError( - f"Results before and after weights conversion do not agree:\n {str(e)}" - ) - - _check(input_data) - - if len(model_spec.inputs) > 1: - return # FIXME: why don't we check multiple inputs? 
- - input_descr = model_spec.inputs[0] - if isinstance(input_descr, v0_4.InputTensorDescr): - if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape): - return - min_shape = input_descr.shape.min - step = input_descr.shape.step - else: - min_shape: List[int] = [] - step: List[int] = [] - for axis in input_descr.axes: - if isinstance(axis.size, v0_5.ParameterizedSize): - min_shape.append(axis.size.min) - step.append(axis.size.step) - elif isinstance(axis.size, int): - min_shape.append(axis.size) - step.append(0) - elif axis.size is None: - raise NotImplementedError( - f"Can't verify inputs that don't specify their shape fully: {axis}" - ) - elif isinstance(axis.size, v0_5.SizeReference): - raise NotImplementedError(f"Can't handle axes like '{axis}' yet") - else: - assert_never(axis.size) - - half_step = [st // 2 for st in step] - max_steps = 4 - - # check that input and output agree for decreasing input sizes - for step_factor in range(1, max_steps + 1): - slice_ = tuple( - slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) - for st in half_step - ) - this_input = [inp[slice_] for inp in input_data] - this_shape = this_input[0].shape - if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)): - raise ValueError( - f"Mismatched shapes: {this_shape}. Expected at least {min_shape}" - ) - _check(this_input) - - -def convert_weights_to_torchscript( - model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], - output_path: Path, - use_tracing: bool = True, -) -> v0_5.TorchscriptWeightsDescr: - """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. - - Args: - model_descr: location of the resource for the input bioimageio model - output_path: where to save the torchscript weights - use_tracing: whether to use tracing or scripting to export the torchscript format - """ - - state_dict_weights_descr = model_descr.weights.pytorch_state_dict - if state_dict_weights_descr is None: - raise ValueError( - "The provided model does not have weights in the pytorch state dict format" - ) - - input_data = model_descr.get_input_test_arrays() - - with torch.no_grad(): - input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data] - - model = load_torch_model(state_dict_weights_descr) - - # FIXME: remove Any - if use_tracing: - scripted_model: Any = torch.jit.trace(model, input_data) - else: - scripted_model: Any = torch.jit.script(model) - - _check_predictions( - model=model, - scripted_model=scripted_model, - model_spec=model_descr, - input_data=input_data, - ) - - # save the torchscript model - scripted_model.save( - str(output_path) - ) # does not support Path, so need to cast to str - - return v0_5.TorchscriptWeightsDescr( - source=output_path, - pytorch_version=Version(torch.__version__), - parent="pytorch_state_dict", - ) diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py deleted file mode 100644 index 01df0747..00000000 --- a/bioimageio/core/weight_converter/torch/_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Union - -from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.utils import download - -try: - import torch -except ImportError: - torch = None - - -# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too -# and for each weight format -def load_torch_model( # pyright: ignore[reportUnknownParameterType] - node: 
Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr],
-):
-    assert torch is not None
-    model = (  # pyright: ignore[reportUnknownVariableType]
-        PytorchModelAdapter.get_network(node)
-    )
-    state = torch.load(download(node.source).path, map_location="cpu")
-    model.load_state_dict(state)  # FIXME: check incompatible keys?
-    return model.eval()  # pyright: ignore[reportUnknownVariableType]
diff --git a/bioimageio/core/weight_converters/__init__.py b/bioimageio/core/weight_converters/__init__.py
new file mode 100644
index 00000000..31a91642
--- /dev/null
+++ b/bioimageio/core/weight_converters/__init__.py
@@ -0,0 +1,3 @@
+from ._add_weights import add_weights
+
+__all__ = ["add_weights"]
diff --git a/bioimageio/core/weight_converters/_add_weights.py b/bioimageio/core/weight_converters/_add_weights.py
new file mode 100644
index 00000000..978c8450
--- /dev/null
+++ b/bioimageio/core/weight_converters/_add_weights.py
@@ -0,0 +1,173 @@
+import traceback
+from typing import Optional
+
+from loguru import logger
+from pydantic import DirectoryPath
+
+from bioimageio.spec import (
+    InvalidDescr,
+    load_model_description,
+    save_bioimageio_package_as_folder,
+)
+from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
+
+from .._resource_tests import load_description_and_test
+
+
+def add_weights(
+    model_descr: ModelDescr,
+    *,
+    output_path: DirectoryPath,
+    source_format: Optional[WeightsFormat] = None,
+    target_format: Optional[WeightsFormat] = None,
+    verbose: bool = False,
+) -> Optional[ModelDescr]:
+    """Convert model weights to other formats and add them to the model description
+
+    Args:
+        output_path: Path to save updated model package to.
+        source_format: convert from a specific weights format.
+            Default: choose automatically from any available.
+        target_format: convert to a specific weights format.
+            Default: attempt to convert to any missing format.
+        verbose: log more (error) output
+
+    Returns:
+        - An updated model description if any converted weights were added.
+        - `None` if no conversion was possible.
+    """
+    if not isinstance(model_descr, ModelDescr):
+        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
+            raise TypeError(
+                f"Model format {model_descr.format_version} is not supported, please update"
+                + f" model to format {ModelDescr.implemented_format_version} first."
+            )
+
+        raise TypeError(type(model_descr))
+
+    # save model to local folder
+    output_path = save_bioimageio_package_as_folder(
+        model_descr, output_path=output_path
+    )
+    # reload from local folder to make sure we do not edit the given model
+    _model_descr = load_model_description(output_path, perform_io_checks=False)
+    assert isinstance(_model_descr, ModelDescr)
+    model_descr = _model_descr
+    del _model_descr
+
+    if source_format is None:
+        available = set(model_descr.weights.available_formats)
+    else:
+        available = {source_format}
+
+    if target_format is None:
+        missing = set(model_descr.weights.missing_formats)
+    else:
+        missing = {target_format}
+
+    originally_missing = set(missing)
+
+    if "pytorch_state_dict" in available and "torchscript" in missing:
+        logger.info(
+            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by scripting."
+        )
+        from .pytorch_to_torchscript import convert
+
+        try:
+            torchscript_weights_path = output_path / "weights_torchscript.pt"
+            model_descr.weights.torchscript = convert(
+                model_descr,
+                output_path=torchscript_weights_path,
+                use_tracing=False,
+            )
+        except Exception as e:
+            if verbose:
+                traceback.print_exception(e)
+
+            logger.error(e)
+        else:
+            available.add("torchscript")
+            missing.discard("torchscript")
+
+    if "pytorch_state_dict" in available and "torchscript" in missing:
+        logger.info(
+            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
+        )
+        from .pytorch_to_torchscript import convert
+
+        try:
+            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"
+
+            model_descr.weights.torchscript = convert(
+                model_descr,
+                output_path=torchscript_weights_path,
+                use_tracing=True,
+            )
+        except Exception as e:
+            if verbose:
+                traceback.print_exception(e)
+
+            logger.error(e)
+        else:
+            available.add("torchscript")
+            missing.discard("torchscript")
+
+    if "torchscript" in available and "onnx" in missing:
+        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
+        from .torchscript_to_onnx import convert
+
+        try:
+            onnx_weights_path = output_path / "weights.onnx"
+            model_descr.weights.onnx = convert(
+                model_descr,
+                output_path=onnx_weights_path,
+            )
+        except Exception as e:
+            if verbose:
+                traceback.print_exception(e)
+
+            logger.error(e)
+        else:
+            available.add("onnx")
+            missing.discard("onnx")
+
+    if "pytorch_state_dict" in available and "onnx" in missing:
+        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
+        from .pytorch_to_onnx import convert
+
+        try:
+            onnx_weights_path = output_path / "weights.onnx"
+
+            model_descr.weights.onnx = convert(
+                model_descr,
+                output_path=onnx_weights_path,
+                verbose=verbose,
+            )
+        except Exception as e:
+            if verbose:
+                traceback.print_exception(e)
+
+            logger.error(e)
+        else:
+            available.add("onnx")
+            missing.discard("onnx")
+
+    if missing:
+        logger.warning(
+            f"Converting from any of the available weights formats {available} to any"
+            + f" of {missing} failed or is not yet implemented. Please create an issue"
+            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
+            + " if you would like bioimageio.core to support a particular conversion."
+        )
+
+    if originally_missing == missing:
+        logger.warning("failed to add any converted weights")
+        return None
+    else:
+        logger.info("added weights formats {}", originally_missing - missing)
+        # resave model with updated rdf.yaml
+        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
+        tested_model_descr = load_description_and_test(model_descr)
+        assert isinstance(tested_model_descr, ModelDescr)
+        return tested_model_descr
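For orientation, a minimal sketch of how this new entry point might be driven; the model source below is a hypothetical placeholder, and `load_model_description` is the same `bioimageio.spec` helper the module imports above:

```python
from pathlib import Path

from bioimageio.spec import load_model_description
from bioimageio.spec.model.v0_5 import ModelDescr

from bioimageio.core.weight_converters import add_weights

# hypothetical package; any v0.5 model with e.g. pytorch_state_dict weights works
model = load_model_description("my_model.zip", perform_io_checks=False)
assert isinstance(model, ModelDescr)

# attempts all missing formats (torchscript via scripting/tracing, then onnx),
# re-saves the package and re-tests it; returns None if nothing could be added
updated = add_weights(model, output_path=Path("model_with_more_weights"))
if updated is not None:
    print(updated.weights.available_formats)
```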
diff --git a/bioimageio/core/weight_converters/_utils_onnx.py b/bioimageio/core/weight_converters/_utils_onnx.py
new file mode 100644
index 00000000..3c45d245
--- /dev/null
+++ b/bioimageio/core/weight_converters/_utils_onnx.py
@@ -0,0 +1,15 @@
+from collections import defaultdict
+from itertools import chain
+from typing import DefaultDict, Dict
+
+from bioimageio.spec.model.v0_5 import ModelDescr
+
+
+def get_dynamic_axes(model_descr: ModelDescr):
+    dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
+    for d in chain(model_descr.inputs, model_descr.outputs):
+        for i, ax in enumerate(d.axes):
+            if not isinstance(ax.size, int):
+                dynamic_axes[str(d.id)][i] = str(ax.id)
+
+    return dynamic_axes
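The mapping this helper builds plugs straight into `torch.onnx.export(..., dynamic_axes=...)` in the two ONNX converters below. For a hypothetical model with a single input `raw` of axes (batch, channel, y, x) where y and x are parameterized, the result would look roughly like:

```python
# batch and the parameterized y/x axes have non-int sizes -> marked dynamic;
# the fixed channel axis (int size) is left out
{"raw": {0: "batch", 2: "y", 3: "x"}}
```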
diff --git a/bioimageio/core/weight_converters/keras_to_tensorflow.py b/bioimageio/core/weight_converters/keras_to_tensorflow.py
new file mode 100644
index 00000000..ac8886e1
--- /dev/null
+++ b/bioimageio/core/weight_converters/keras_to_tensorflow.py
@@ -0,0 +1,183 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Union, no_type_check
+from zipfile import ZipFile
+
+import tensorflow  # pyright: ignore[reportMissingTypeStubs]
+
+from bioimageio.spec._internal.io import download
+from bioimageio.spec._internal.version_type import Version
+from bioimageio.spec.common import ZipPath
+from bioimageio.spec.model.v0_5 import (
+    InputTensorDescr,
+    ModelDescr,
+    OutputTensorDescr,
+    TensorflowSavedModelBundleWeightsDescr,
+)
+
+from .. import __version__
+from ..io import ensure_unzipped
+
+try:
+    # try to build the tf model with the keras import from tensorflow
+    from tensorflow import keras  # type: ignore
+except Exception:
+    # if the above fails try to export with the standalone keras
+    import keras  # pyright: ignore[reportMissingTypeStubs]
+
+
+def convert(
+    model_descr: ModelDescr, output_path: Path
+) -> TensorflowSavedModelBundleWeightsDescr:
+    """
+    Convert model weights from the 'keras_hdf5' format to the 'tensorflow_saved_model_bundle' format.
+
+    This method handles the conversion of Keras HDF5 model weights into a TensorFlow SavedModel bundle,
+    which is the recommended format for deploying TensorFlow models. The method supports both TensorFlow 1.x
+    and 2.x versions, with appropriate checks to ensure compatibility.
+
+    Adapted from:
+    https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py
+
+    Args:
+        model_descr:
+            The bioimage.io model description containing the model's metadata and weights.
+        output_path:
+            Path with .zip suffix (.zip is appended otherwise) to which a zip archive
+            with the TensorFlow SavedModel bundle will be saved.
+
+    Raises:
+        ValueError:
+            - If the specified `output_path` already exists.
+            - If the Keras HDF5 weights are missing in the model description.
+        RuntimeError:
+            If there is a mismatch between the TensorFlow version used by the model and the version installed.
+        NotImplementedError:
+            If the model has multiple inputs or outputs and TensorFlow 1.x is being used.
+
+    Returns:
+        A descriptor object containing information about the converted TensorFlow SavedModel bundle.
+    """
+    tf_major_ver = int(tensorflow.__version__.split(".")[0])
+
+    if output_path.suffix == ".zip":
+        output_path = output_path.with_suffix("")
+
+    if output_path.exists():
+        raise ValueError(f"The output directory at {output_path} must not exist.")
+
+    if model_descr.weights.keras_hdf5 is None:
+        raise ValueError("Missing Keras HDF5 weights to convert from.")
+
+    weight_spec = model_descr.weights.keras_hdf5
+    weight_path = download(weight_spec.source).path
+
+    if weight_spec.tensorflow_version:
+        model_tf_major_ver = int(weight_spec.tensorflow_version.major)
+        if model_tf_major_ver != tf_major_ver:
+            raise RuntimeError(
+                f"TensorFlow major version {model_tf_major_ver} of the model does not match the installed version {tf_major_ver}."
+            )
+
+    if tf_major_ver == 1:
+        if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1:
+            raise NotImplementedError(
+                "Weight conversion for models with multiple inputs or outputs is not yet implemented."
+            )
+
+        input_name = str(
+            d.id
+            if isinstance((d := model_descr.inputs[0]), InputTensorDescr)
+            else d.name
+        )
+        output_name = str(
+            d.id
+            if isinstance((d := model_descr.outputs[0]), OutputTensorDescr)
+            else d.name
+        )
+        return _convert_tf1(
+            ensure_unzipped(weight_path, Path("bioimageio_unzipped_tf_weights")),
+            output_path,
+            input_name,
+            output_name,
+        )
+    else:
+        return _convert_tf2(weight_path, output_path)
+
+
+def _convert_tf2(
+    keras_weight_path: Union[Path, ZipPath], output_path: Path
+) -> TensorflowSavedModelBundleWeightsDescr:
+    model = keras.models.load_model(keras_weight_path)  # type: ignore
+    model.export(output_path)  # type: ignore
+
+    output_path = _zip_model_bundle(output_path)
+    print("TensorFlow model exported to", output_path)
+
+    return TensorflowSavedModelBundleWeightsDescr(
+        source=output_path,
+        parent="keras_hdf5",
+        tensorflow_version=Version(tensorflow.__version__),
+        comment=f"Converted with bioimageio.core {__version__}.",
+    )
+
+
+# adapted from
+# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
+def _convert_tf1(
+    keras_weight_path: Path,
+    output_path: Path,
+    input_name: str,
+    output_name: str,
+) -> TensorflowSavedModelBundleWeightsDescr:
+
+    @no_type_check
+    def build_tf_model():
+        keras_model = keras.models.load_model(keras_weight_path)
+        builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path)
+        signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
+            inputs={input_name: keras_model.input},
+            outputs={output_name: keras_model.output},
+        )
+
+        signature_def_map = {
+            tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: (
+                signature
+            )
+        }
+
+        builder.add_meta_graph_and_variables(
+            keras.backend.get_session(),
+            [tensorflow.saved_model.tag_constants.SERVING],
+            signature_def_map=signature_def_map,
+        )
+        builder.save()
+
+    build_tf_model()
+
+    output_path = _zip_model_bundle(output_path)
+    print("TensorFlow model exported to", output_path)
+
+    return TensorflowSavedModelBundleWeightsDescr(
+        source=output_path,
+        parent="keras_hdf5",
+        tensorflow_version=Version(tensorflow.__version__),
+        comment=f"Converted with bioimageio.core {__version__}.",
+    )
+
+
+def _zip_model_bundle(model_bundle_folder: Path):
+    zipped_model_bundle = model_bundle_folder.with_suffix(".zip")
+
+    with ZipFile(zipped_model_bundle, "w") as zip_obj:
+        for root, _, files in os.walk(model_bundle_folder):
+            for filename in files:
+                src = os.path.join(root, filename)
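A usage sketch for the Keras converter; the model source is a hypothetical placeholder, and the returned descriptor is attached to the description so the package can be re-saved or re-tested:

```python
from pathlib import Path

from bioimageio.spec import load_model_description
from bioimageio.spec.model.v0_5 import ModelDescr

from bioimageio.core.weight_converters.keras_to_tensorflow import convert

model = load_model_description("my_keras_model.zip")  # hypothetical source
assert isinstance(model, ModelDescr)

# writes a SavedModel bundle and zips it to tf_bundle.zip
model.weights.tensorflow_saved_model_bundle = convert(model, Path("tf_bundle.zip"))
```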
diff --git a/bioimageio/core/weight_converters/pytorch_to_onnx.py b/bioimageio/core/weight_converters/pytorch_to_onnx.py
new file mode 100644
index 00000000..72d819b1
--- /dev/null
+++ b/bioimageio/core/weight_converters/pytorch_to_onnx.py
@@ -0,0 +1,79 @@
+from pathlib import Path
+
+import torch
+
+from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
+
+from .. import __version__
+from ..backends.pytorch_backend import load_torch_model
+from ..digest_spec import get_member_id, get_test_inputs
+from ..proc_setup import get_pre_and_postprocessing
+from ._utils_onnx import get_dynamic_axes
+
+
+def convert(
+    model_descr: ModelDescr,
+    output_path: Path,
+    *,
+    verbose: bool = False,
+    opset_version: int = 15,
+) -> OnnxWeightsDescr:
+    """
+    Convert model weights from the PyTorch state_dict format to the ONNX format.
+
+    Args:
+        model_descr:
+            The model description object that contains the model and its weights.
+        output_path:
+            The file path where the ONNX model will be saved.
+        verbose:
+            If True, will print out detailed information during the ONNX export process. Defaults to False.
+        opset_version:
+            The ONNX opset version to use for the export. Defaults to 15.
+
+    Raises:
+        ValueError:
+            If the provided model does not have weights in the PyTorch state_dict format.
+
+    Returns:
+        A descriptor object that contains information about the exported ONNX weights.
+    """
+
+    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
+    if state_dict_weights_descr is None:
+        raise ValueError(
+            "The provided model does not have weights in the pytorch state dict format"
+        )
+
+    sample = get_test_inputs(model_descr)
+    procs = get_pre_and_postprocessing(
+        model_descr, dataset_for_initial_statistics=[sample]
+    )
+    procs.pre(sample)
+    inputs_numpy = [
+        sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
+    ]
+    inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
+    model = load_torch_model(state_dict_weights_descr, load_state=True)
+    with torch.no_grad():
+        outputs_original_torch = model(*inputs_torch)
+        if isinstance(outputs_original_torch, torch.Tensor):
+            outputs_original_torch = [outputs_original_torch]
+
+        _ = torch.onnx.export(
+            model,
+            tuple(inputs_torch),
+            str(output_path),
+            input_names=[str(d.id) for d in model_descr.inputs],
+            output_names=[str(d.id) for d in model_descr.outputs],
+            dynamic_axes=get_dynamic_axes(model_descr),
+            verbose=verbose,
+            opset_version=opset_version,
+        )
+
+    return OnnxWeightsDescr(
+        source=output_path,
+        parent="pytorch_state_dict",
+        opset_version=opset_version,
+        comment=f"Converted with bioimageio.core {__version__}.",
+    )
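Worth noting: `convert` runs the model's preprocessing on the test inputs before export, so the ONNX graph is traced with the same tensors the prediction pipeline would produce. A sketch under the same placeholder assumptions as above:

```python
from pathlib import Path

from bioimageio.spec import load_model_description
from bioimageio.spec.model.v0_5 import ModelDescr

from bioimageio.core.weight_converters.pytorch_to_onnx import convert

model = load_model_description("my_torch_model.zip")  # hypothetical source
assert isinstance(model, ModelDescr)

model.weights.onnx = convert(model, Path("weights.onnx"), opset_version=15)
```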
diff --git a/bioimageio/core/weight_converters/pytorch_to_torchscript.py b/bioimageio/core/weight_converters/pytorch_to_torchscript.py
new file mode 100644
index 00000000..3d0f281c
--- /dev/null
+++ b/bioimageio/core/weight_converters/pytorch_to_torchscript.py
@@ -0,0 +1,70 @@
+from pathlib import Path
+from typing import Any, Tuple, Union
+
+import torch
+from torch.jit import ScriptModule
+
+from bioimageio.spec._internal.version_type import Version
+from bioimageio.spec.model.v0_5 import ModelDescr, TorchscriptWeightsDescr
+
+from .. import __version__
+from ..backends.pytorch_backend import load_torch_model
+
+
+def convert(
+    model_descr: ModelDescr,
+    output_path: Path,
+    *,
+    use_tracing: bool = True,
+) -> TorchscriptWeightsDescr:
+    """
+    Convert model weights from the PyTorch `state_dict` format to TorchScript.
+
+    Args:
+        model_descr:
+            The model description object that contains the model and its weights in the PyTorch `state_dict` format.
+        output_path:
+            The file path where the TorchScript model will be saved.
+        use_tracing:
+            Whether to use tracing or scripting to export the TorchScript format.
+            - `True`: Use tracing, which is recommended for models with straightforward control flow.
+            - `False`: Use scripting, which is better for models with dynamic control flow (e.g., loops, conditionals).
+
+    Raises:
+        ValueError:
+            If the provided model does not have weights in the PyTorch `state_dict` format.
+
+    Returns:
+        A descriptor object that contains information about the exported TorchScript weights.
+    """
+    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
+    if state_dict_weights_descr is None:
+        raise ValueError(
+            "The provided model does not have weights in the pytorch state dict format"
+        )
+
+    input_data = model_descr.get_input_test_arrays()
+
+    with torch.no_grad():
+        input_data = [torch.from_numpy(inp) for inp in input_data]
+        model = load_torch_model(state_dict_weights_descr, load_state=True)
+        scripted_model: Union[  # pyright: ignore[reportUnknownVariableType]
+            ScriptModule, Tuple[Any, ...]
+        ] = (
+            torch.jit.trace(model, input_data)
+            if use_tracing
+            else torch.jit.script(model)
+        )
+        assert not isinstance(scripted_model, tuple), scripted_model
+
+    scripted_model.save(output_path)
+
+    return TorchscriptWeightsDescr(
+        source=output_path,
+        pytorch_version=Version(torch.__version__),
+        parent="pytorch_state_dict",
+        comment=(
+            f"Converted with bioimageio.core {__version__}"
+            + f" with use_tracing={use_tracing}."
+        ),
+    )
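A sketch of the TorchScript converter under the same placeholder assumptions; `use_tracing=False` is the safer choice for models with data-dependent control flow, which is also why `add_weights` above tries scripting before tracing:

```python
from pathlib import Path

from bioimageio.spec import load_model_description
from bioimageio.spec.model.v0_5 import ModelDescr

from bioimageio.core.weight_converters.pytorch_to_torchscript import convert

model = load_model_description("my_torch_model.zip")  # hypothetical source
assert isinstance(model, ModelDescr)

model.weights.torchscript = convert(
    model, Path("weights_torchscript.pt"), use_tracing=False  # script, don't trace
)
```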
diff --git a/bioimageio/core/weight_converters/torchscript_to_onnx.py b/bioimageio/core/weight_converters/torchscript_to_onnx.py
new file mode 100644
index 00000000..d58b47ab
--- /dev/null
+++ b/bioimageio/core/weight_converters/torchscript_to_onnx.py
@@ -0,0 +1,84 @@
+from pathlib import Path
+
+import torch.jit
+
+from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
+from bioimageio.spec.utils import download
+
+from .. import __version__
+from ..digest_spec import get_member_id, get_test_inputs
+from ..proc_setup import get_pre_and_postprocessing
+from ._utils_onnx import get_dynamic_axes
+
+
+def convert(
+    model_descr: ModelDescr,
+    output_path: Path,
+    *,
+    verbose: bool = False,
+    opset_version: int = 15,
+) -> OnnxWeightsDescr:
+    """
+    Convert model weights from the TorchScript format to the ONNX format.
+
+    Args:
+        model_descr:
+            The model description object that contains the model and its weights.
+        output_path:
+            The file path where the ONNX model will be saved.
+        verbose:
+            If True, will print out detailed information during the ONNX export process. Defaults to False.
+        opset_version:
+            The ONNX opset version to use for the export. Defaults to 15.
+
+    Raises:
+        ValueError:
+            If the provided model does not have weights in the torchscript format.
+
+    Returns:
+        A descriptor object that contains information about the exported ONNX weights.
+    """
+
+    torchscript_descr = model_descr.weights.torchscript
+    if torchscript_descr is None:
+        raise ValueError(
+            "The provided model does not have weights in the torchscript format"
+        )
+
+    sample = get_test_inputs(model_descr)
+    procs = get_pre_and_postprocessing(
+        model_descr, dataset_for_initial_statistics=[sample]
+    )
+    procs.pre(sample)
+    inputs_numpy = [
+        sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
+    ]
+    inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
+
+    weight_path = download(torchscript_descr).path
+    model = torch.jit.load(weight_path)  # type: ignore
+    model.to("cpu")
+    model = model.eval()  # type: ignore
+
+    with torch.no_grad():
+        outputs_original_torch = model(*inputs_torch)  # type: ignore
+        if isinstance(outputs_original_torch, torch.Tensor):
+            outputs_original_torch = [outputs_original_torch]
+
+        _ = torch.onnx.export(
+            model,  # type: ignore
+            tuple(inputs_torch),
+            str(output_path),
+            input_names=[str(d.id) for d in model_descr.inputs],
+            output_names=[str(d.id) for d in model_descr.outputs],
+            dynamic_axes=get_dynamic_axes(model_descr),
+            verbose=verbose,
+            opset_version=opset_version,
+        )
+
+    return OnnxWeightsDescr(
+        source=output_path,
+        parent="torchscript",
+        opset_version=opset_version,
+        comment=f"Converted with bioimageio.core {__version__}.",
+    )
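Both ONNX paths share the dynamic-axes helper; chaining the converters gives pytorch_state_dict -> torchscript -> onnx. A sketch under the same placeholder assumptions (with only torchscript weights available, the second step alone suffices):

```python
from pathlib import Path

from bioimageio.spec import load_model_description
from bioimageio.spec.model.v0_5 import ModelDescr

from bioimageio.core.weight_converters import (
    pytorch_to_torchscript,
    torchscript_to_onnx,
)

model = load_model_description("my_torch_model.zip")  # hypothetical source
assert isinstance(model, ModelDescr)

# pytorch_state_dict -> torchscript -> onnx
model.weights.torchscript = pytorch_to_torchscript.convert(
    model, Path("weights_torchscript.pt")
)
model.weights.onnx = torchscript_to_onnx.convert(model, Path("weights.onnx"))
```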
diff --git a/dev/env-wo-python.yaml b/dev/env-dev.yaml
similarity index 59%
rename from dev/env-wo-python.yaml
rename to dev/env-dev.yaml
index d8cba289..13378376 100644
--- a/dev/env-wo-python.yaml
+++ b/dev/env-dev.yaml
@@ -1,21 +1,23 @@
-# modified copy of env.yaml
+# modified copy of env-full.yaml wo dependencies 'for model testing'
 name: core
 channels:
   - conda-forge
   - nodefaults
-  - pytorch # added
+  - pytorch
 dependencies:
-  - bioimageio.spec>=0.5.3.5
+  - bioimageio.spec==0.5.4.1
   - black
   # - crick # currently requires python<=3.9
-  - filelock
   - h5py
+  - imagecodecs
   - imageio>=2.5
   - jupyter
   - jupyter-black
-  - keras>=3.0
+  - keras>=3.0,<4
   - loguru
+  - matplotlib
   - numpy
+  - onnx
   - onnxruntime
   - packaging>=17.0
   - pdoc
@@ -27,16 +29,16 @@ dependencies:
   - pyright
   - pytest
   - pytest-cov
-  - pytest-xdist
-  # - python=3.9 # removed
-  - pytorch>=2.1
+  # - python=3.11 # removed
+  - pytorch>=2.1,<3
   - requests
   - rich
   - ruff
   - ruyaml
+  - tensorflow>=2,<3
   - torchvision
   - tqdm
   - typing-extensions
-  - xarray
+  - xarray>=2024.01,<2025.3.0
   - pip:
     - -e ..
diff --git a/dev/env-full.yaml b/dev/env-full.yaml
new file mode 100644
index 00000000..a9dc0132
--- /dev/null
+++ b/dev/env-full.yaml
@@ -0,0 +1,49 @@
+name: core-full
+channels:
+  - conda-forge
+  - nodefaults
+  - pytorch
+dependencies:
+  - bioimageio.spec==0.5.4.1
+  - black
+  # - careamics # TODO: add careamics for model testing (currently pins pydantic to <2.9)
+  - cellpose # for model testing
+  # - crick # currently requires python<=3.9
+  - h5py
+  - imagecodecs
+  - imageio>=2.5
+  - jupyter
+  - jupyter-black
+  - keras>=3.0,<4
+  - loguru
+  - matplotlib
+  - monai # for model testing
+  - numpy
+  - onnx
+  - onnxruntime
+  - packaging>=17.0
+  - pdoc
+  - pip
+  - pre-commit
+  - psutil
+  - pydantic
+  - pydantic-settings
+  - pyright
+  - pytest
+  - pytest-cov
+  - python=3.11 # 3.12 not supported by cellpose->fastremap
+  - pytorch>=2.1,<3
+  - requests
+  - rich
+  - ruff
+  - ruyaml
+  - segment-anything # for model testing
+  - tensorflow>=2,<3
+  - timm # for model testing
+  - torchvision>=0.21
+  - tqdm
+  - typing-extensions
+  - xarray>=2024.01,<2025.3.0
+  - pip:
+    - git+https://github.com/ChaoningZhang/MobileSAM.git # for model testing
+    - -e ..
diff --git a/dev/env-gpu.yaml b/dev/env-gpu.yaml
new file mode 100644
index 00000000..7fc2123c
--- /dev/null
+++ b/dev/env-gpu.yaml
@@ -0,0 +1,52 @@
+# version of env-full for running on GPU
+name: core-gpu
+channels:
+  - conda-forge
+  - nodefaults
+dependencies:
+  - bioimageio.spec==0.5.4.1
+  - black
+  - cellpose # for model testing
+  # - crick # currently requires python<=3.9
+  - h5py
+  - imagecodecs
+  - imageio>=2.5
+  - jupyter
+  - jupyter-black
+  - keras>=3.0,<4
+  - loguru
+  - matplotlib
+  - monai # for model testing
+  - numpy
+  - onnx
+  - packaging>=17.0
+  - pdoc
+  - pip
+  - pre-commit
+  - psutil
+  - pydantic<2.9
+  - pydantic-settings
+  - pyright
+  - pytest
+  - pytest-cov
+  - python=3.11
+  - requests
+  - rich
+  - ruff
+  - ruyaml
+  - segment-anything # for model testing
+  - timm # for model testing
+  - tqdm
+  - typing-extensions
+  - xarray>=2024.01,<2025.3.0
+  - pip:
+    # - tf2onnx # TODO: add tf2onnx
+    - --extra-index-url https://download.pytorch.org/whl/cu126
+    - careamics # TODO: add careamics for model testing (currently pins pydantic to <2.9)
+    - git+https://github.com/ChaoningZhang/MobileSAM.git # for model testing
+    - onnxruntime-gpu
+    - tensorflow
+    - torch
+    - torchaudio
+    - torchvision>=0.21
+    - -e ..
diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml
index 22353103..6fc6597a 100644
--- a/dev/env-py38.yaml
+++ b/dev/env-py38.yaml
@@ -1,20 +1,23 @@
-# manipulated copy of env.yaml
-name: core38
+# manipulated copy of env-full.yaml wo dependencies 'for model testing' for python 3.8
+name: core-py38
 channels:
   - conda-forge
   - nodefaults
+  - pytorch
 dependencies:
-  - bioimageio.spec>=0.5.3.5
+  - bioimageio.spec==0.5.4.1
   - black
   - crick # uncommented
-  - filelock
   - h5py
+  - imagecodecs
   - imageio>=2.5
   - jupyter
   - jupyter-black
-  # - keras>=3.0 # removed
+  # - keras>=3.0,<4 # removed
   - loguru
+  - matplotlib
   - numpy
+  - onnx
   - onnxruntime
   - packaging>=17.0
   - pdoc
@@ -26,16 +29,16 @@ dependencies:
   - pyright
   - pytest
   - pytest-cov
-  - pytest-xdist
   - python=3.8 # changed
-  - pytorch>=2.1
+  - pytorch>=2.1,<3
   - requests
   - rich
   - ruff
   - ruyaml
+  # - tensorflow>=2,<3 removed
   - torchvision
   - tqdm
   - typing-extensions
-  - xarray
+  - xarray>=2023.01,<2025.3.0
   - pip:
     - -e ..
diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml deleted file mode 100644 index 0df6fd07..00000000 --- a/dev/env-tf.yaml +++ /dev/null @@ -1,42 +0,0 @@ -# modified copy of env.yaml -name: core-tf # changed -channels: - - conda-forge - - nodefaults -dependencies: - - bioimageio.spec>=0.5.3.5 - - black - # - crick # currently requires python<=3.9 - - filelock - - h5py - - imageio>=2.5 - - jupyter - - jupyter-black - - keras>=2.15 # changed - - loguru - - numpy - - onnxruntime - - packaging>=17.0 - - pdoc - - pip - - pre-commit - - psutil - - pydantic - - pydantic-settings - - pyright - - pytest - - pytest-cov - - pytest-xdist - # - python=3.9 # removed - # - pytorch>=2.1 # removed - - requests - - rich - # - ruff # removed - - ruyaml - - tensorflow>=2.15 # added - # - torchvision # removed - - tqdm - - typing-extensions - - xarray - - pip: - - -e .. diff --git a/dev/env.yaml b/dev/env.yaml deleted file mode 100644 index 20d60a18..00000000 --- a/dev/env.yaml +++ /dev/null @@ -1,41 +0,0 @@ -name: core -channels: - - conda-forge -dependencies: - - bioimageio.spec>=0.5.3.5 - - black - # - crick # currently requires python<=3.9 - - filelock - - h5py - - imageio>=2.5 - - jupyter - - jupyter-black - - ipykernel - - matplotlib - - keras>=3.0 - - loguru - - numpy - - onnxruntime - - packaging>=17.0 - - pdoc - - pip - - pre-commit - - psutil - - pydantic - - pydantic-settings - - pyright - - pytest - - pytest-cov - - pytest-xdist - - python=3.9 - - pytorch>=2.1 - - requests - - rich - - ruff - - ruyaml - - torchvision - - tqdm - - typing-extensions - - xarray - - pip: - - -e .. diff --git a/presentations/create_ambitious_sloth.ipynb b/presentations/create_ambitious_sloth.ipynb index 171b30db..8cda7fec 100644 --- a/presentations/create_ambitious_sloth.ipynb +++ b/presentations/create_ambitious_sloth.ipynb @@ -465,7 +465,7 @@ } ], "source": [ - "pytorch_weights = torch.load(root / \"weights.pt\", weights_only=False)\n", + "pytorch_weights = torch.load(root / \"weights.pt\", weights_only=True)\n", "pprint([(k, tuple(v.shape)) for k, v in pytorch_weights.items()][:4] + [\"...\"])" ] }, diff --git a/pyproject.toml b/pyproject.toml index 91cd2cbc..5d58fe72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 88 -extend_exclude = "/presentations/" +extend-exclude = "/presentations/" target-version = ["py38", "py39", "py310", "py311", "py312"] preview = true @@ -8,6 +8,7 @@ preview = true exclude = [ "**/__pycache__", "**/node_modules", + "dogfood", "presentations", "scripts/pdoc/original.py", "scripts/pdoc/patched.py", @@ -39,7 +40,8 @@ typeCheckingMode = "strict" useLibraryCodeForTypes = true [tool.pytest.ini_options] -addopts = "--cov=bioimageio --cov-report=xml -n auto --capture=no --doctest-modules --failed-first" +addopts = "--cov bioimageio --cov-report xml --cov-append --capture no --doctest-modules --failed-first --ignore dogfood --ignore bioimageio/core/backends --ignore bioimageio/core/weight_converters" +testpaths = ["bioimageio/core", "tests"] [tool.ruff] line-length = 88 diff --git a/scripts/show_diff.py b/scripts/show_diff.py index 1b0163bb..3e273d79 100644 --- a/scripts/show_diff.py +++ b/scripts/show_diff.py @@ -2,14 +2,14 @@ from pathlib import Path from tempfile import TemporaryDirectory -import pooch +import pooch # pyright: ignore[reportMissingTypeStubs] from bioimageio.core import load_description, save_bioimageio_yaml_only if __name__ == "__main__": rdf_source = 
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/v0_4_9.bioimageio.yaml" - local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore + local_source = Path(pooch.retrieve(rdf_source, None)) model_as_is = load_description(rdf_source, format_version="discover") model_latest = load_description(rdf_source, format_version="latest") print(model_latest.validation_summary) diff --git a/setup.py b/setup.py index 2465ff7e..0755ff2d 100644 --- a/setup.py +++ b/setup.py @@ -30,8 +30,9 @@ ], packages=find_namespace_packages(exclude=["tests"]), install_requires=[ - "bioimageio.spec ==0.5.3.5", + "bioimageio.spec ==0.5.4.1", "h5py", + "imagecodecs", "imageio>=2.10", "loguru", "numpy", @@ -41,33 +42,37 @@ "ruyaml", "tqdm", "typing-extensions", - "xarray<2025.3.0", + "xarray>=2023.01,<2025.3.0", ], include_package_data=True, extras_require={ - "pytorch": ["torch>=1.6", "torchvision", "keras>=3.0"], - "tensorflow": ["tensorflow", "keras>=2.15"], + "pytorch": ( + pytorch_deps := ["torch>=1.6,<3", "torchvision>=0.21", "keras>=3.0,<4"] + ), + "tensorflow": ["tensorflow", "keras>=2.15,<4"], "onnx": ["onnxruntime"], - "dev": [ - "black", - # "crick", # currently requires python<=3.9 - "filelock", - "jupyter", - "jupyter-black", - "matplotlib", - "keras>=3.0", - "onnxruntime", - "packaging>=17.0", - "pre-commit", - "pdoc", - "psutil", # parallel pytest with 'pytest -n auto' - "pyright", - "pytest-cov", - "pytest-xdist", # parallel pytest - "pytest", - "torch>=1.6", - "torchvision", - ], + "dev": ( + pytorch_deps + + [ + "black", + "cellpose", # for model testing + "jupyter-black", + "jupyter", + "matplotlib", + "monai", # for model testing + "onnx", + "onnxruntime", + "packaging>=17.0", + "pdoc", + "pre-commit", + "pyright==1.1.396", + "pytest-cov", + "pytest", + "segment-anything", # for model testing + "timm", # for model testing + # "crick", # currently requires python<=3.9 + ] + ), }, project_urls={ "Bug Reports": "https://github.com/bioimage-io/core-bioimage-io-python/issues", diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 253ade2f..32880b05 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,23 @@ from __future__ import annotations import subprocess -import warnings from itertools import chain from typing import Dict, List from loguru import logger from pytest import FixtureRequest, fixture +from bioimageio.core import enable_determinism from bioimageio.spec import __version__ as bioimageio_spec_version +enable_determinism() + + try: import torch torch_version = tuple(map(int, torch.__version__.split(".")[:2])) - logger.warning(f"detected torch version {torch_version}.x") + logger.warning("detected torch version {}", torch.__version__) except ImportError: torch = None torch_version = None @@ -29,7 +32,7 @@ try: import tensorflow # type: ignore - tf_major_version = int(tensorflow.__version__.split(".")[0]) # type: ignore + tf_major_version = int(tensorflow.__version__.split(".")[0]) except ImportError: tensorflow = None tf_major_version = None @@ -41,56 +44,50 @@ skip_tensorflow = tensorflow is None -warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") +logger.warning("testing with bioimageio.spec {}", bioimageio_spec_version) + +EXAMPLE_DESCRIPTIONS = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/" # TODO: use models from new collection on S3 
MODEL_SOURCES: Dict[str, str] = { - "hpa_densenet": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" - ), + "hpa_densenet": "polite-pig/1.1", "stardist": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" - "/stardist_example_model/v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/stardist_example_model/v0_4.bioimageio.yaml" ), "shape_change": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "upsample_test_model/v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/upsample_test_model/v0_4.bioimageio.yaml" ), "stardist_wrong_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "stardist_example_model/rdf_wrong_shape.yaml" + EXAMPLE_DESCRIPTIONS + "models/stardist_example_model/rdf_wrong_shape.yaml" ), "stardist_wrong_shape2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "stardist_example_model/rdf_wrong_shape2_v0_4.yaml" + EXAMPLE_DESCRIPTIONS + + "models/stardist_example_model/rdf_wrong_shape2_v0_4.yaml" ), "unet2d_diff_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_diff_output_shape/v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/unet2d_diff_output_shape/bioimageio.yaml" ), "unet2d_expand_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + + "models/unet2d_nuclei_broad/expand_output_shape.bioimageio.yaml" ), "unet2d_fixed_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_fixed_shape/v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/unet2d_fixed_shape/v0_4.bioimageio.yaml" ), "unet2d_keras_tf2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf2/v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/unet2d_keras_tf2/v0_4.bioimageio.yaml" ), "unet2d_keras": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf/v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/unet2d_keras_tf/v0_4.bioimageio.yaml" ), "unet2d_multi_tensor": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_multi_tensor/v0_4.bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/unet2d_multi_tensor/bioimageio.yaml" + ), + "unet2d_nuclei_broad_model_old": ( + EXAMPLE_DESCRIPTIONS + "models/unet2d_nuclei_broad/v0_4_9.bioimageio.yaml" ), "unet2d_nuclei_broad_model": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/bioimageio.yaml" + EXAMPLE_DESCRIPTIONS + "models/unet2d_nuclei_broad/bioimageio.yaml" ), } @@ -244,6 +241,14 @@ def unet2d_nuclei_broad_model(request: FixtureRequest): return MODEL_SOURCES[request.param] +# written as model group to automatically skip on missing torch +@fixture( + scope="session", params=[] if skip_torch else ["unet2d_nuclei_broad_model_old"] +) +def unet2d_nuclei_broad_model_old(request: FixtureRequest): + return MODEL_SOURCES[request.param] + + # written as model group to automatically skip on missing torch 
@fixture(scope="session", params=[] if skip_torch else ["unet2d_diff_output_shape"]) def unet2d_diff_output_shape(request: FixtureRequest): diff --git a/tests/weight_converter/test_add_weights.py b/tests/test_add_weights.py similarity index 100% rename from tests/weight_converter/test_add_weights.py rename to tests/test_add_weights.py diff --git a/tests/test_any_model_fixture.py b/tests/test_any_model_fixture.py index a4cc1bce..77225f18 100644 --- a/tests/test_any_model_fixture.py +++ b/tests/test_any_model_fixture.py @@ -3,4 +3,4 @@ def test_model(any_model: str): summary = load_description_and_validate_format_only(any_model) - assert summary.status == "passed", summary.format() + assert summary.status == "valid-format", summary.display() diff --git a/tests/test_bioimageio_collection.py b/tests/test_bioimageio_collection.py new file mode 100644 index 00000000..92f3dd5c --- /dev/null +++ b/tests/test_bioimageio_collection.py @@ -0,0 +1,113 @@ +import os +from typing import Any, Collection, Dict, Iterable, Mapping, Tuple + +import pytest +import requests +from pydantic import HttpUrl + +from bioimageio.spec import InvalidDescr +from bioimageio.spec.common import Sha256 +from tests.utils import ParameterSet, expensive_test + +BASE_URL = "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/" + + +def _get_latest_rdf_sources(): + entries: Any = requests.get(BASE_URL + "all_versions.json").json()["entries"] + ret: Dict[str, Tuple[HttpUrl, Sha256]] = {} + for entry in entries: + version = entry["versions"][0] + ret[f"{entry['concept']}/{version['v']}"] = ( + HttpUrl(version["source"]), # pyright: ignore[reportCallIssue] + Sha256(version["sha256"]), + ) + + return ret + + +ALL_LATEST_RDF_SOURCES: Mapping[str, Tuple[HttpUrl, Sha256]] = _get_latest_rdf_sources() + + +def yield_bioimageio_yaml_urls() -> Iterable[ParameterSet]: + for descr_url, sha in ALL_LATEST_RDF_SOURCES.values(): + key = ( + str(descr_url) + .replace(BASE_URL, "") + .replace("/files/rdf.yaml", "") + .replace("/files/bioimageio.yaml", "") + ) + yield pytest.param(descr_url, sha, key, id=key) + + +KNOWN_INVALID: Collection[str] = { + "affable-shark/1.1", # onnx weights expect fixed input shape + "affectionate-cow/0.1.0", # custom dependencies + "ambitious-sloth/1.2", # requires inferno + "committed-turkey/1.2", # error deserializing VarianceScaling + "creative-panda/1", # error deserializing Conv2D + "dazzling-spider/0.1.0", # requires careamics + "discreet-rooster/1", # error deserializing VarianceScaling + "discreete-rooster/1", # error deserializing VarianceScaling + "dynamic-t-rex/1", # needs update to 0.5 for scale_linear with axes processing + "easy-going-sauropod/1", # CPU implementation of Conv3D currently only supports the NHWC tensor format. 
+ "efficient-chipmunk/1", # needs plantseg + "famous-fish/0.1.0", # list index out of range `fl[3]` + "greedy-whale/1", # batch size is actually limited to 1 + "happy-elephant/0.1.0", # list index out of range `fl[3]` + "happy-honeybee/0.1.0", # requires biapy + "heroic-otter/0.1.0", # requires biapy + "humorous-crab/1", # batch size is actually limited to 1 + "humorous-fox/0.1.0", # requires careamics + "humorous-owl/1", # error deserializing GlorotUniform + "idealistic-turtle/0.1.0", # requires biapy + "impartial-shark/1", # error deserializing VarianceScaling + "intelligent-lion/0.1.0", # requires biapy + "joyful-deer/1", # needs update to 0.5 for scale_linear with axes processing + "merry-water-buffalo/0.1.0", # requires biapy + "naked-microbe/1", # unknown layer Convolution2D + "noisy-ox/1", # batch size is actually limited to 1 + "non-judgemental-eagle/1", # error deserializing GlorotUniform + "straightforward-crocodile/1", # needs update to 0.5 for scale_linear with axes processing + "stupendous-sheep/1.1", # requires relativ import of attachment + "stupendous-sheep/1.2", + "venomous-swan/0.1.0", # requires biapy + "wild-rhino/0.1.0", # requires careamics +} + + +@pytest.mark.parametrize("descr_url,sha,key", list(yield_bioimageio_yaml_urls())) +def test_rdf_format_to_populate_cache( + descr_url: HttpUrl, + sha: Sha256, + key: str, +): + """this test is redundant if `test_rdf` runs, but is used in the CI to populate the cache""" + if os.environ.get("BIOIMAGEIO_POPULATE_CACHE") != "1": + pytest.skip("only runs in CI to populate cache") + + if key in KNOWN_INVALID: + pytest.skip("known failure") + + from bioimageio.core import load_description + + _ = load_description(descr_url, sha256=sha, perform_io_checks=True) + + +@expensive_test +@pytest.mark.parametrize("descr_url,sha,key", list(yield_bioimageio_yaml_urls())) +def test_rdf( + descr_url: HttpUrl, + sha: Sha256, + key: str, +): + if key in KNOWN_INVALID: + pytest.skip("known failure") + + from bioimageio.core import load_description_and_test + + descr = load_description_and_test(descr_url, sha256=sha, stop_early=True) + + assert not isinstance(descr, InvalidDescr), descr.validation_summary.display() + assert ( + descr.validation_summary.status == "passed" + ), descr.validation_summary.display() diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py index 921ecd9c..2418baa5 100644 --- a/tests/test_bioimageio_spec_version.py +++ b/tests/test_bioimageio_spec_version.py @@ -8,7 +8,7 @@ def test_bioimageio_spec_version(conda_cmd: Optional[str]): if conda_cmd is None: - pytest.skip("requires mamba") + pytest.skip("requires conda") from importlib.metadata import metadata diff --git a/tests/test_cli.py b/tests/test_cli.py index e0828ac6..203677ec 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import subprocess +from pathlib import Path from typing import Any, List, Sequence import pytest @@ -36,11 +37,30 @@ def run_subprocess( ], ["test", "unet2d_nuclei_broad_model"], ["predict", "--example", "unet2d_nuclei_broad_model"], + ["update-format", "unet2d_nuclei_broad_model_old"], + ["add-weights", "unet2d_nuclei_broad_model", "tmp_path"], + ["update-hashes", "unet2d_nuclei_broad_model_old"], + ["update-hashes", "unet2d_nuclei_broad_model_old", "--output=stdout"], ], ) -def test_cli(args: List[str], unet2d_nuclei_broad_model: str): +def test_cli( + args: List[str], + unet2d_nuclei_broad_model: str, + unet2d_nuclei_broad_model_old: str, + tmp_path: Path, +): resolved_args = [ - 
str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg + ( + unet2d_nuclei_broad_model + if arg == "unet2d_nuclei_broad_model" + else ( + unet2d_nuclei_broad_model_old + if arg == "unet2d_nuclei_broad_model_old" + else ( + arg.replace("tmp_path", str(tmp_path)) if "tmp_path" in arg else arg + ) + ) + ) for arg in args ] ret = run_subprocess(["bioimageio", *resolved_args]) diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index a0a85f5d..08e9f094 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -2,12 +2,15 @@ from numpy.testing import assert_array_almost_equal +from bioimageio.core.common import SupportedWeightsFormat from bioimageio.spec import load_description from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 -from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat +from bioimageio.spec.model.v0_5 import ModelDescr -def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat): +def _test_prediction_pipeline( + model_package: Path, weights_format: SupportedWeightsFormat +): from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index 0e241df1..0d2ff9b7 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -1,17 +1,14 @@ from pathlib import Path +import pytest from numpy.testing import assert_array_almost_equal -from bioimageio.core.utils.testing import skip_on +from bioimageio.core.common import SupportedWeightsFormat from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 -from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat +from bioimageio.spec.model.v0_5 import ModelDescr -class TooFewDevicesException(Exception): - pass - - -def _test_device_management(model_package: Path, weight_format: WeightsFormat): +def _test_device_management(model_package: Path, weight_format: SupportedWeightsFormat): import torch from bioimageio.core import load_description @@ -19,7 +16,7 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs if not hasattr(torch, "cuda") or torch.cuda.device_count() == 0: - raise TooFewDevicesException("Need at least one cuda device for this test") + pytest.skip("Need at least one cuda device for this test") bio_model = load_description(model_package) assert isinstance(bio_model, (ModelDescr, ModelDescr04)) @@ -52,26 +49,21 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert_array_almost_equal(out, exp, decimal=4) -@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_torch(any_torch_model: Path): _test_device_management(any_torch_model, "pytorch_state_dict") -@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_torchscript(any_torchscript_model: Path): _test_device_management(any_torchscript_model, "torchscript") -@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_onnx(any_onnx_model: Path): _test_device_management(any_onnx_model, "onnx") -@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_tensorflow(any_tensorflow_model: Path): 
_test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") -@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_keras(any_keras_model: Path): _test_device_management(any_keras_model, "keras_hdf5") diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index e408d220..be87f54b 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -21,15 +21,17 @@ def tid(): def test_scale_linear(tid: MemberId): from bioimageio.core.proc_ops import ScaleLinear - offset = xr.DataArray([1, 2, 42], dims=("c")) - gain = xr.DataArray([1, 2, 3], dims=("c")) - data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) + offset = xr.DataArray([1, 2, 42], dims=("channel",)) + gain = xr.DataArray([1, 2, 3], dims=("channel",)) + data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "channel")) sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain) op(sample) - expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c")) + expected = xr.DataArray( + np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "channel") + ) xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7) @@ -84,10 +86,10 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId): op = FixedZeroMeanUnitVariance( tid, tid, - mean=xr.DataArray([3, 4, 5], dims=("c")), - std=xr.DataArray([2.44948974, 2.44948974, 2.44948974], dims=("c")), + mean=xr.DataArray([3, 4, 5], dims=("channel",)), + std=xr.DataArray([2.44948974, 2.44948974, 2.44948974], dims=("channel",)), ) - data = xr.DataArray(np.arange(9).reshape((1, 3, 3)), dims=("b", "c", "x")) + data = xr.DataArray(np.arange(9).reshape((1, 3, 3)), dims=("batch", "channel", "x")) expected = xr.DataArray( np.array( [ @@ -98,17 +100,33 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId): ] ] ), - dims=("b", "c", "x"), + dims=("batch", "channel", "x"), ) sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) op(sample) xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7) +def test_zero_mean_unit_variance_fixed2(tid: MemberId): + from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance + + np_data = np.arange(9).reshape(3, 3) + mean = float(np_data.mean()) + std = float(np_data.mean()) + eps = 1.0e-7 + op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) + + data = xr.DataArray(np_data, dims=("x", "y")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) + op(sample) + xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7) + + def test_zero_mean_unit_across_axes(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance - data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) + data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("channel", "x", "y")) op = ZeroMeanUnitVariance( tid, @@ -120,33 +138,18 @@ def test_zero_mean_unit_across_axes(tid: MemberId): sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.concat( - [(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], dim="c" + [(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], + dim="channel", ) op(sample) xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7) -def 
test_zero_mean_unit_variance_fixed2(tid: MemberId): - from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance - - np_data = np.arange(9).reshape(3, 3) - mean = float(np_data.mean()) - std = float(np_data.mean()) - eps = 1.0e-7 - op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) - - data = xr.DataArray(np_data, dims=("x", "y")) - sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) - expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) - op(sample) - xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-5, atol=1e-7) - - def test_binarize(tid: MemberId): from bioimageio.core.proc_ops import Binarize op = Binarize(tid, tid, threshold=14) - data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) + data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "channel")) sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) expected = xr.zeros_like(data) expected[{"x": slice(1, None)}] = 1 @@ -158,7 +161,7 @@ def test_binarize2(tid: MemberId): from bioimageio.core.proc_ops import Binarize shape = (3, 32, 32) - axes = ("c", "y", "x") + axes = ("channel", "y", "x") np_data = np.random.rand(*shape) data = xr.DataArray(np_data, dims=axes) @@ -188,7 +191,7 @@ def test_clip(tid: MemberId): def test_combination_of_op_steps_with_dims_specified(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance - data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) + data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("channel", "x", "y")) sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) op = ZeroMeanUnitVariance( tid, @@ -219,7 +222,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: MemberId): ], ] ), - dims=("c", "x", "y"), + dims=("channel", "x", "y"), ) op(sample) @@ -239,7 +242,7 @@ def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]): from bioimageio.core.proc_ops import ScaleMeanVariance shape = (3, 32, 46) - ipt_axes = ("c", "y", "x") + ipt_axes = ("channel", "y", "x") np_data = np.random.rand(*shape) ipt_data = xr.DataArray(np_data, dims=ipt_axes) ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) @@ -268,7 +271,7 @@ def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str]) axes = None if axes_str is None else tuple(map(AxisId, axes_str)) shape = (3, 32, 46) - ipt_axes = ("c", "y", "x") + ipt_axes = ("channel", "y", "x") np_data = np.random.rand(*shape) ipt_data = xr.DataArray(np_data, dims=ipt_axes) @@ -334,7 +337,7 @@ def test_scale_range_axes(tid: MemberId): op = ScaleRange(tid, tid, lower_quantile, upper_quantile, eps=eps) np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") - data = Tensor.from_xarray(xr.DataArray(np_data, dims=("c", "x", "y"))) + data = Tensor.from_xarray(xr.DataArray(np_data, dims=("channel", "x", "y"))) sample = Sample(members={tid: data}, stat={}, id=None) p_low_direct = lower_quantile.compute(sample) @@ -352,7 +355,7 @@ def test_scale_range_axes(tid: MemberId): np.testing.assert_allclose(p_up_expected.squeeze(), sample.stat[upper_quantile]) exp_data = (np_data - p_low_expected) / (p_up_expected - p_low_expected + eps) - expected = xr.DataArray(exp_data, dims=("c", "x", "y")) + expected = xr.DataArray(exp_data, dims=("channel", "x", "y")) op(sample) # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct @@ -363,7 +366,7 @@ def 
     from bioimageio.core.proc_ops import Sigmoid
 
     shape = (3, 32, 32)
-    axes = ("c", "y", "x")
+    axes = ("channel", "y", "x")
     np_data = np.random.rand(*shape)
     data = xr.DataArray(np_data, dims=axes)
     sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py
index 203ca64b..f4eca96b 100644
--- a/tests/test_resource_tests.py
+++ b/tests/test_resource_tests.py
@@ -1,15 +1,4 @@
-from typing import Literal
-
-import pytest
-
-from bioimageio.spec import InvalidDescr
-
-
-@pytest.mark.parametrize("mode", ["seed_only", "full"])
-def test_enable_determinism(mode: Literal["seed_only", "full"]):
-    from bioimageio.core import enable_determinism
-
-    enable_determinism(mode)
+from bioimageio.spec import InvalidDescr, ValidationContext
 
 
 def test_error_for_wrong_shape(stardist_wrong_shape: str):
@@ -38,15 +27,10 @@ def test_error_for_wrong_shape2(stardist_wrong_shape2: str):
 def test_test_model(any_model: str):
     from bioimageio.core._resource_tests import test_model
 
-    summary = test_model(any_model)
-    assert summary.status == "passed", summary.format()
-
-
-def test_test_resource(any_model: str):
-    from bioimageio.core._resource_tests import test_description
+    with ValidationContext(raise_errors=True):
+        summary = test_model(any_model)
 
-    summary = test_description(any_model)
-    assert summary.status == "passed", summary.format()
+    assert summary.status == "passed", summary.display()
 
 
 def test_loading_description_multiple_times(unet2d_nuclei_broad_model: str):
diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py
index 57e86c5a..dd2823b6 100644
--- a/tests/test_stat_calculators.py
+++ b/tests/test_stat_calculators.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple
 
 import numpy as np
 import pytest
@@ -12,6 +12,9 @@
     DatasetMean,
     DatasetStd,
     DatasetVar,
+    SampleMean,
+    SampleStd,
+    SampleVar,
 )
 from bioimageio.core.tensor import Tensor
 
@@ -27,18 +30,53 @@ def create_random_dataset(tid: MemberId, axes: Tuple[AxisId, ...]):
     return Tensor(data, dims=axes), ds
 
 
+@pytest.mark.parametrize(
+    "axes",
+    [
+        (AxisId("x"), AxisId("y")),
+        (AxisId("channel"), AxisId("y")),
+    ],
+)
+def test_sample_mean_var_std_calculator(axes: Optional[Tuple[AxisId, ...]]):
+    tid = MemberId("tensor")
+    d_axes = tuple(map(AxisId, ("batch", "channel", "x", "y")))
+    data, ds = create_random_dataset(tid, d_axes)
+    expected_mean = data[0:1].mean(axes)
+    expected_var = data[0:1].var(axes)
+    expected_std = data[0:1].std(axes)
+
+    calc = MeanVarStdCalculator(tid, axes=axes)
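+
+    # compute() returns a mapping from measure descriptors (SampleMean,
+    # SampleVar, SampleStd) to their computed values for this sample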
+    actual = calc.compute(ds[0])
+    actual_mean = actual[SampleMean(member_id=tid, axes=axes)]
+    actual_var = actual[SampleVar(member_id=tid, axes=axes)]
+    actual_std = actual[SampleStd(member_id=tid, axes=axes)]
+
+    assert_allclose(
+        actual_mean if isinstance(actual_mean, (int, float)) else actual_mean.data,
+        expected_mean.data,
+    )
+    assert_allclose(
+        actual_var if isinstance(actual_var, (int, float)) else actual_var.data,
+        expected_var.data,
+    )
+    assert_allclose(
+        actual_std if isinstance(actual_std, (int, float)) else actual_std.data,
+        expected_std.data,
+    )
+
+
 @pytest.mark.parametrize(
     "axes",
     [
         None,
-        ("x", "y"),
-        ("channel", "y"),
+        (AxisId("batch"), AxisId("channel"), AxisId("x"), AxisId("y")),
     ],
 )
-def test_mean_var_std_calculator(axes: Union[None, str, Tuple[str, ...]]):
+def test_dataset_mean_var_std_calculator(axes: Optional[Tuple[AxisId, ...]]):
     tid = MemberId("tensor")
-    axes = tuple(map(AxisId, ("batch", "channel", "x", "y")))
-    data, ds = create_random_dataset(tid, axes)
+    d_axes = tuple(map(AxisId, ("batch", "channel", "x", "y")))
+    data, ds = create_random_dataset(tid, d_axes)
     expected_mean = data.mean(axes)
     expected_var = data.var(axes)
     expected_std = data.std(axes)
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index 33163077..c57980bd 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -1,3 +1,5 @@
+from typing import Sequence
+
 import numpy as np
 import pytest
 import xarray as xr
@@ -8,9 +10,19 @@
 
 @pytest.mark.parametrize(
     "axes",
-    ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"],
+    [
+        "yx",
+        "xy",
+        "cyx",
+        "yxc",
+        "bczyx",
+        "xyz",
+        "xyzc",
+        "bzyxc",
+        ("batch", "channel", "x", "y"),
+    ],
 )
-def test_transpose_tensor_2d(axes: str):
+def test_transpose_tensor_2d(axes: Sequence[str]):
     tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None)
     transposed = tensor.transpose([AxisId(a) for a in axes])
 
@@ -19,9 +31,18 @@ def test_transpose_tensor_2d(axes: str):
 
 @pytest.mark.parametrize(
     "axes",
-    ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"],
+    [
+        "zyx",
+        "cyzx",
+        "yzixc",
+        "bczyx",
+        "xyz",
+        "xyzc",
+        "bzyxtc",
+        ("batch", "channel", "x", "y", "z"),
+    ],
 )
-def test_transpose_tensor_3d(axes: str):
+def test_transpose_tensor_3d(axes: Sequence[str]):
     tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None)
     transposed = tensor.transpose([AxisId(a) for a in axes])
     assert transposed.ndim == len(axes)
@@ -39,3 +60,8 @@ def test_crop_and_pad():
 def test_some_magic_ops():
     tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None)
     assert tensor + 2 == 2 + tensor
+
+
+def test_shape_attributes():
+    tensor = Tensor.from_numpy(np.random.rand(1, 2, 25, 26), dims=None)
+    assert tensor.shape_tuple == tensor.shape
diff --git a/tests/test_weight_converters.py b/tests/test_weight_converters.py
new file mode 100644
index 00000000..1bf65782
--- /dev/null
+++ b/tests/test_weight_converters.py
@@ -0,0 +1,108 @@
+# type: ignore # TODO enable type checking
+import os
+import zipfile
+from pathlib import Path
+
+import pytest
+
+from bioimageio.spec import load_description
+from bioimageio.spec.model import v0_5
+
+
+def test_pytorch_to_torchscript(any_torch_model, tmp_path):
+    from bioimageio.core import test_model
+    from bioimageio.core.weight_converters.pytorch_to_torchscript import convert
+
+    model_descr = load_description(any_torch_model, perform_io_checks=False)
+    if model_descr.implemented_format_version_tuple[:2] == (0, 4):
+        pytest.skip("cannot convert to old 0.4 format")
+
+    out_path = tmp_path / "weights.pt"
+    ret_val = convert(model_descr, out_path)
+    assert out_path.exists()
+    assert isinstance(ret_val, v0_5.TorchscriptWeightsDescr)
+    assert ret_val.source == out_path
+    model_descr.weights.torchscript = ret_val
+    summary = test_model(model_descr, weight_format="torchscript")
+    assert summary.status == "passed", summary.display()
+
+
+def test_pytorch_to_onnx(convert_to_onnx, tmp_path):
+    from bioimageio.core import test_model
+    from bioimageio.core.weight_converters.pytorch_to_onnx import convert
+
+    model_descr = load_description(convert_to_onnx, format_version="latest")
+    out_path = tmp_path / "weights.onnx"
+    opset_version = 15
+    ret_val = convert(
+        model_descr=model_descr,
+        output_path=out_path,
+        opset_version=opset_version,
+    )
+    assert os.path.exists(out_path)
+    assert isinstance(ret_val, v0_5.OnnxWeightsDescr)
+    assert ret_val.opset_version == opset_version
+    assert ret_val.source == out_path
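+
+    # attach the converted weights to the description and re-run the
+    # model test against the new weight format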
+    model_descr.weights.onnx = ret_val
+    summary = test_model(model_descr, weight_format="onnx")
+    assert summary.status == "passed", summary.display()
+
+
+@pytest.mark.skip()
+def test_keras_to_tensorflow(any_keras_model: Path, tmp_path: Path):
+    from bioimageio.core import test_model
+    from bioimageio.core.weight_converters.keras_to_tensorflow import convert
+
+    out_path = tmp_path / "weights.zip"
+    model_descr = load_description(any_keras_model)
+    ret_val = convert(model_descr, out_path)
+
+    assert out_path.exists()
+    assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)
+
+    expected_names = {"saved_model.pb", "variables/variables.index"}
+    with zipfile.ZipFile(out_path, "r") as f:
+        names = set([name for name in f.namelist()])
+    assert len(expected_names - names) == 0
+
+    model_descr.weights.keras = ret_val
+    summary = test_model(model_descr, weight_format="keras_hdf5")
+    assert summary.status == "passed", summary.display()
+
+
+# TODO: add tensorflow_to_keras converter
+# def test_tensorflow_to_keras(any_tensorflow_model: Path, tmp_path: Path):
+#     from bioimageio.core.weight_converters.tensorflow_to_keras import convert
+
+#     model_descr = load_description(any_tensorflow_model)
+#     out_path = tmp_path / "weights.h5"
+#     ret_val = convert(model_descr, output_path=out_path)
+#     assert out_path.exists()
+#     assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)
+#     assert ret_val.source == out_path

+#     model_descr.weights.keras = ret_val
+#     summary = test_model(model_descr, weight_format="keras_hdf5")
+#     assert summary.status == "passed", summary.display()
+
+
+# @pytest.mark.skip()
+# def test_tensorflow_to_keras_zipped(any_tensorflow_model: Path, tmp_path: Path):
+#     from bioimageio.core.weight_converters.tensorflow_to_keras import convert

+#     out_path = tmp_path / "weights.zip"
+#     model_descr = load_description(any_tensorflow_model)
+#     ret_val = convert(model_descr, out_path)

+#     assert out_path.exists()
+#     assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr)

+#     expected_names = {"saved_model.pb", "variables/variables.index"}
+#     with zipfile.ZipFile(out_path, "r") as f:
+#         names = set([name for name in f.namelist()])
+#     assert len(expected_names - names) == 0

+#     model_descr.weights.keras = ret_val
+#     summary = test_model(model_descr, weight_format="keras_hdf5")
+#     assert summary.status == "passed", summary.display()
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..f9116fa5
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,21 @@
+"""utils to test bioimageio.core"""
+
+import os
+from typing import Any, Protocol, Sequence
+
+import pytest
+
+
+class ParameterSet(Protocol):
+    def __init__(self, values: Sequence[Any], marks: Any, id: str) -> None:
+        super().__init__()
+
+
+class test_func(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
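+
+
+# Marker that skips a test unless the environment variable
+# RUN_EXPENSIVE_TESTS is set to 'true'.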
+expensive_test = pytest.mark.skipif(
+    os.getenv("RUN_EXPENSIVE_TESTS") != "true",
+    reason="Skipping expensive test (enable by RUN_EXPENSIVE_TESTS='true')",
+)
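+
+
+# Hypothetical usage sketch (test and fixture names are illustrative only):
+#
+#     @expensive_test
+#     def test_full_collection(any_torch_model):
+#         ...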
yet" -) # TODO: test torchscript converter -def test_torchscript_converter( - any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path -): - from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript - - out_path = tmp_path / "weights.pt" - ret_val = convert_weights_to_torchscript(any_torch_model, out_path) - assert out_path.exists() - assert ( - ret_val == 0 - ) # check for correctness is done in converter and returns 0 if it passes