diff --git a/bioimageio/core/weight_converter/__init__.py b/bioimageio/core/weight_converter/__init__.py deleted file mode 100644 index 5f1674c9..00000000 --- a/bioimageio/core/weight_converter/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""coming soon""" diff --git a/bioimageio/core/weight_converter/keras/__init__.py b/bioimageio/core/weight_converter/keras/__init__.py deleted file mode 100644 index 195b42b8..00000000 --- a/bioimageio/core/weight_converter/keras/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: update keras weight converters diff --git a/bioimageio/core/weight_converter/keras/_tensorflow.py b/bioimageio/core/weight_converter/keras/_tensorflow.py deleted file mode 100644 index c901f458..00000000 --- a/bioimageio/core/weight_converter/keras/_tensorflow.py +++ /dev/null @@ -1,151 +0,0 @@ -# type: ignore # TODO: type -import os -import shutil -from pathlib import Path -from typing import no_type_check -from zipfile import ZipFile - -try: - import tensorflow.saved_model -except Exception: - tensorflow = None - -from bioimageio.spec._internal.io_utils import download -from bioimageio.spec.model.v0_5 import ModelDescr - - -def _zip_model_bundle(model_bundle_folder: Path): - zipped_model_bundle = model_bundle_folder.with_suffix(".zip") - - with ZipFile(zipped_model_bundle, "w") as zip_obj: - for root, _, files in os.walk(model_bundle_folder): - for filename in files: - src = os.path.join(root, filename) - zip_obj.write(src, os.path.relpath(src, model_bundle_folder)) - - try: - shutil.rmtree(model_bundle_folder) - except Exception: - print("TensorFlow bundled model was not removed after compression") - - return zipped_model_bundle - - -# adapted from -# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 -def _convert_tf1( - keras_weight_path: Path, - output_path: Path, - input_name: str, - output_name: str, - zip_weights: bool, -): - try: - # try to build the tf model with the keras import from tensorflow - from bioimageio.core.weight_converter.keras._tensorflow import ( - keras, # type: ignore - ) - - except Exception: - # if the above fails try to export with the standalone keras - import keras - - @no_type_check - def build_tf_model(): - keras_model = keras.models.load_model(keras_weight_path) - assert tensorflow is not None - builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path) - signature = tensorflow.saved_model.signature_def_utils.predict_signature_def( - inputs={input_name: keras_model.input}, - outputs={output_name: keras_model.output}, - ) - - signature_def_map = { - tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature - } - - builder.add_meta_graph_and_variables( - keras.backend.get_session(), - [tensorflow.saved_model.tag_constants.SERVING], - signature_def_map=signature_def_map, - ) - builder.save() - - build_tf_model() - - if zip_weights: - output_path = _zip_model_bundle(output_path) - print("TensorFlow model exported to", output_path) - - return 0 - - -def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool): - try: - # try to build the tf model with the keras import from tensorflow - from bioimageio.core.weight_converter.keras._tensorflow import keras - except Exception: - # if the above fails try to export with the standalone keras - import keras - - model = keras.models.load_model(keras_weight_path) - keras.models.save_model(model, output_path) - - if zip_weights: - output_path = _zip_model_bundle(output_path) - print("TensorFlow model exported to", output_path) - - return 0 - - -def convert_weights_to_tensorflow_saved_model_bundle( - model: ModelDescr, output_path: Path -): - """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'. - - Adapted from - https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py - - Args: - model: The bioimageio model description - output_path: where to save the tensorflow weights. This path must not exist yet. - """ - assert tensorflow is not None - tf_major_ver = int(tensorflow.__version__.split(".")[0]) - - if output_path.suffix == ".zip": - output_path = output_path.with_suffix("") - zip_weights = True - else: - zip_weights = False - - if output_path.exists(): - raise ValueError(f"The ouptut directory at {output_path} must not exist.") - - if model.weights.keras_hdf5 is None: - raise ValueError("Missing Keras Hdf5 weights to convert from.") - - weight_spec = model.weights.keras_hdf5 - weight_path = download(weight_spec.source).path - - if weight_spec.tensorflow_version: - model_tf_major_ver = int(weight_spec.tensorflow_version.major) - if model_tf_major_ver != tf_major_ver: - raise RuntimeError( - f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}" - ) - - if tf_major_ver == 1: - if len(model.inputs) != 1 or len(model.outputs) != 1: - raise NotImplementedError( - "Weight conversion for models with multiple inputs or outputs is not yet implemented." - ) - return _convert_tf1( - weight_path, - output_path, - model.inputs[0].id, - model.outputs[0].id, - zip_weights, - ) - else: - return _convert_tf2(weight_path, output_path, zip_weights) diff --git a/bioimageio/core/weight_converter/torch/__init__.py b/bioimageio/core/weight_converter/torch/__init__.py deleted file mode 100644 index 1b1ba526..00000000 --- a/bioimageio/core/weight_converter/torch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: torch weight converters diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py deleted file mode 100644 index 3935e1d1..00000000 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ /dev/null @@ -1,108 +0,0 @@ -# type: ignore # TODO: type -import warnings -from pathlib import Path -from typing import Any, List, Sequence, cast - -import numpy as np -from numpy.testing import assert_array_almost_equal - -from bioimageio.spec import load_description -from bioimageio.spec.common import InvalidDescr -from bioimageio.spec.model import v0_4, v0_5 - -from ...digest_spec import get_member_id, get_test_inputs -from ...weight_converter.torch._utils import load_torch_model - -try: - import torch -except ImportError: - torch = None - - -def add_onnx_weights( - model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", - *, - output_path: Path, - use_tracing: bool = True, - test_decimal: int = 4, - verbose: bool = False, - opset_version: "int | None" = None, -): - """Convert model weights from format 'pytorch_state_dict' to 'onnx'. - - Args: - source_model: model without onnx weights - opset_version: onnx opset version - use_tracing: whether to use tracing or scripting to export the onnx format - test_decimal: precision for testing whether the results agree - """ - if isinstance(model_spec, (str, Path)): - loaded_spec = load_description(Path(model_spec)) - if isinstance(loaded_spec, InvalidDescr): - raise ValueError(f"Bad resource description: {loaded_spec}") - if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError( - f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr" - ) - model_spec = loaded_spec - - state_dict_weights_descr = model_spec.weights.pytorch_state_dict - if state_dict_weights_descr is None: - raise ValueError( - "The provided model does not have weights in the pytorch state dict format" - ) - - assert torch is not None - with torch.no_grad(): - - sample = get_test_inputs(model_spec) - input_data = [sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs] - input_tensors = [torch.from_numpy(ipt) for ipt in input_data] - model = load_torch_model(state_dict_weights_descr) - - expected_tensors = model(*input_tensors) - if isinstance(expected_tensors, torch.Tensor): - expected_tensors = [expected_tensors] - expected_outputs: List[np.ndarray[Any, Any]] = [ - out.numpy() for out in expected_tensors - ] - - if use_tracing: - torch.onnx.export( - model, - tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0], - str(output_path), - verbose=verbose, - opset_version=opset_version, - ) - else: - raise NotImplementedError - - try: - import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs] - except ImportError: - msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked." - warnings.warn(msg) - return - - # check the onnx model - sess = rt.InferenceSession(str(output_path)) - onnx_input_node_args = cast( - List[Any], sess.get_inputs() - ) # fixme: remove cast, try using rt.NodeArg instead of Any - onnx_inputs = { - input_name.name: inp - for input_name, inp in zip(onnx_input_node_args, input_data) - } - outputs = cast( - Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs) - ) # FIXME: remove cast - - try: - for exp, out in zip(expected_outputs, outputs): - assert_array_almost_equal(exp, out, decimal=test_decimal) - return 0 - except AssertionError as e: - msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}" - warnings.warn(msg) - return 1 diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py deleted file mode 100644 index 5ca16069..00000000 --- a/bioimageio/core/weight_converter/torch/_torchscript.py +++ /dev/null @@ -1,146 +0,0 @@ -# type: ignore # TODO: type -from pathlib import Path -from typing import List, Sequence, Union - -import numpy as np -from numpy.testing import assert_array_almost_equal -from typing_extensions import Any, assert_never - -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import Version - -from ._utils import load_torch_model - -try: - import torch -except ImportError: - torch = None - - -# FIXME: remove Any -def _check_predictions( - model: Any, - scripted_model: Any, - model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", - input_data: Sequence["torch.Tensor"], -): - assert torch is not None - - def _check(input_: Sequence[torch.Tensor]) -> None: - expected_tensors = model(*input_) - if isinstance(expected_tensors, torch.Tensor): - expected_tensors = [expected_tensors] - expected_outputs: List[np.ndarray[Any, Any]] = [ - out.numpy() for out in expected_tensors - ] - - output_tensors = scripted_model(*input_) - if isinstance(output_tensors, torch.Tensor): - output_tensors = [output_tensors] - outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors] - - try: - for exp, out in zip(expected_outputs, outputs): - assert_array_almost_equal(exp, out, decimal=4) - except AssertionError as e: - raise ValueError( - f"Results before and after weights conversion do not agree:\n {str(e)}" - ) - - _check(input_data) - - if len(model_spec.inputs) > 1: - return # FIXME: why don't we check multiple inputs? - - input_descr = model_spec.inputs[0] - if isinstance(input_descr, v0_4.InputTensorDescr): - if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape): - return - min_shape = input_descr.shape.min - step = input_descr.shape.step - else: - min_shape: List[int] = [] - step: List[int] = [] - for axis in input_descr.axes: - if isinstance(axis.size, v0_5.ParameterizedSize): - min_shape.append(axis.size.min) - step.append(axis.size.step) - elif isinstance(axis.size, int): - min_shape.append(axis.size) - step.append(0) - elif axis.size is None: - raise NotImplementedError( - f"Can't verify inputs that don't specify their shape fully: {axis}" - ) - elif isinstance(axis.size, v0_5.SizeReference): - raise NotImplementedError(f"Can't handle axes like '{axis}' yet") - else: - assert_never(axis.size) - - half_step = [st // 2 for st in step] - max_steps = 4 - - # check that input and output agree for decreasing input sizes - for step_factor in range(1, max_steps + 1): - slice_ = tuple( - slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) - for st in half_step - ) - this_input = [inp[slice_] for inp in input_data] - this_shape = this_input[0].shape - if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)): - raise ValueError( - f"Mismatched shapes: {this_shape}. Expected at least {min_shape}" - ) - _check(this_input) - - -def convert_weights_to_torchscript( - model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], - output_path: Path, - use_tracing: bool = True, -) -> v0_5.TorchscriptWeightsDescr: - """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. - - Args: - model_descr: location of the resource for the input bioimageio model - output_path: where to save the torchscript weights - use_tracing: whether to use tracing or scripting to export the torchscript format - """ - - state_dict_weights_descr = model_descr.weights.pytorch_state_dict - if state_dict_weights_descr is None: - raise ValueError( - "The provided model does not have weights in the pytorch state dict format" - ) - - input_data = model_descr.get_input_test_arrays() - - with torch.no_grad(): - input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data] - - model = load_torch_model(state_dict_weights_descr) - - # FIXME: remove Any - if use_tracing: - scripted_model: Any = torch.jit.trace(model, input_data) - else: - scripted_model: Any = torch.jit.script(model) - - _check_predictions( - model=model, - scripted_model=scripted_model, - model_spec=model_descr, - input_data=input_data, - ) - - # save the torchscript model - scripted_model.save( - str(output_path) - ) # does not support Path, so need to cast to str - - return v0_5.TorchscriptWeightsDescr( - source=output_path, - pytorch_version=Version(torch.__version__), - parent="pytorch_state_dict", - ) diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py deleted file mode 100644 index 01df0747..00000000 --- a/bioimageio/core/weight_converter/torch/_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Union - -from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.utils import download - -try: - import torch -except ImportError: - torch = None - - -# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too -# and for each weight format -def load_torch_model( # pyright: ignore[reportUnknownParameterType] - node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], -): - assert torch is not None - model = ( # pyright: ignore[reportUnknownVariableType] - PytorchModelAdapter.get_network(node) - ) - state = torch.load(download(node.source).path, map_location="cpu") - model.load_state_dict(state) # FIXME: check incompatible keys? - return model.eval() # pyright: ignore[reportUnknownVariableType] diff --git a/bioimageio/core/weight_converters.py b/bioimageio/core/weight_converters.py new file mode 100644 index 00000000..6e0d06ec --- /dev/null +++ b/bioimageio/core/weight_converters.py @@ -0,0 +1,492 @@ +# type: ignore # TODO: type +from __future__ import annotations + +import abc +from bioimageio.spec.model.v0_5 import WeightsEntryDescrBase +from typing import Any, List, Sequence, cast, Union +from typing_extensions import assert_never +import numpy as np +from numpy.testing import assert_array_almost_equal +from bioimageio.spec.model import v0_4, v0_5 +from torch.jit import ScriptModule +from bioimageio.core.digest_spec import get_test_inputs, get_member_id +from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter +import os +import shutil +from pathlib import Path +from typing import no_type_check +from zipfile import ZipFile +from bioimageio.spec._internal.version_type import Version +from bioimageio.spec._internal.io_utils import download + +try: + import torch +except ImportError: + torch = None + +try: + import tensorflow.saved_model +except Exception: + tensorflow = None + + +# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too +# and for each weight format +def load_torch_model( # pyright: ignore[reportUnknownParameterType] + node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], +): + assert torch is not None + model = ( # pyright: ignore[reportUnknownVariableType] + PytorchModelAdapter.get_network(node) + ) + state = torch.load(download(node.source).path, map_location="cpu") + model.load_state_dict(state) # FIXME: check incompatible keys? + return model.eval() # pyright: ignore[reportUnknownVariableType] + + +class WeightConverter(abc.ABC): + @abc.abstractmethod + def convert( + self, model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path + ) -> WeightsEntryDescrBase: + raise NotImplementedError + + +class Pytorch2Onnx(WeightConverter): + def __init__(self): + super().__init__() + assert torch is not None + + def convert( + self, + model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], + output_path: Path, + use_tracing: bool = True, + test_decimal: int = 4, + verbose: bool = False, + opset_version: int = 15, + ) -> v0_5.OnnxWeightsDescr: + """ + Convert model weights from the PyTorch state_dict format to the ONNX format. + + Args: + model_descr (Union[v0_4.ModelDescr, v0_5.ModelDescr]): + The model description object that contains the model and its weights. + output_path (Path): + The file path where the ONNX model will be saved. + use_tracing (bool, optional): + Whether to use tracing or scripting to export the ONNX format. Defaults to True. + test_decimal (int, optional): + The decimal precision for comparing the results between the original and converted models. + This is used in the `assert_array_almost_equal` function to check if the outputs match. + Defaults to 4. + verbose (bool, optional): + If True, will print out detailed information during the ONNX export process. Defaults to False. + opset_version (int, optional): + The ONNX opset version to use for the export. Defaults to 15. + + Raises: + ValueError: + If the provided model does not have weights in the PyTorch state_dict format. + ImportError: + If ONNX Runtime is not available for checking the exported ONNX model. + ValueError: + If the results before and after weights conversion do not agree. + + Returns: + v0_5.OnnxWeightsDescr: + A descriptor object that contains information about the exported ONNX weights. + """ + + state_dict_weights_descr = model_descr.weights.pytorch_state_dict + if state_dict_weights_descr is None: + raise ValueError( + "The provided model does not have weights in the pytorch state dict format" + ) + + assert torch is not None + with torch.no_grad(): + sample = get_test_inputs(model_descr) + input_data = [ + sample.members[get_member_id(ipt)].data.data + for ipt in model_descr.inputs + ] + input_tensors = [torch.from_numpy(ipt) for ipt in input_data] + model = load_torch_model(state_dict_weights_descr) + + expected_tensors = model(*input_tensors) + if isinstance(expected_tensors, torch.Tensor): + expected_tensors = [expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [ + out.numpy() for out in expected_tensors + ] + + if use_tracing: + torch.onnx.export( + model, + ( + tuple(input_tensors) + if len(input_tensors) > 1 + else input_tensors[0] + ), + str(output_path), + verbose=verbose, + opset_version=opset_version, + ) + else: + raise NotImplementedError + + try: + import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs] + except ImportError: + raise ImportError( + "The onnx weights were exported, but onnx rt is not available and weights cannot be checked." + ) + + # check the onnx model + sess = rt.InferenceSession(str(output_path)) + onnx_input_node_args = cast( + List[Any], sess.get_inputs() + ) # fixme: remove cast, try using rt.NodeArg instead of Any + onnx_inputs = { + input_name.name: inp + for input_name, inp in zip(onnx_input_node_args, input_data) + } + outputs = cast( + Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs) + ) # FIXME: remove cast + + try: + for exp, out in zip(expected_outputs, outputs): + assert_array_almost_equal(exp, out, decimal=test_decimal) + except AssertionError as e: + raise ValueError( + f"Results before and after weights conversion do not agree:\n {str(e)}" + ) + + return v0_5.OnnxWeightsDescr( + source=output_path, parent="pytorch_state_dict", opset_version=opset_version + ) + + +class Pytorch2Torchscipt(WeightConverter): + def __init__(self): + super().__init__() + assert torch is not None + + def convert( + self, + model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], + output_path: Path, + use_tracing: bool = True, + ) -> v0_5.TorchscriptWeightsDescr: + """ + Convert model weights from the PyTorch `state_dict` format to TorchScript. + + Args: + model_descr (Union[v0_4.ModelDescr, v0_5.ModelDescr]): + The model description object that contains the model and its weights in the PyTorch `state_dict` format. + output_path (Path): + The file path where the TorchScript model will be saved. + use_tracing (bool): + Whether to use tracing or scripting to export the TorchScript format. + - `True`: Use tracing, which is recommended for models with straightforward control flow. + - `False`: Use scripting, which is better for models with dynamic control flow (e.g., loops, conditionals). + + Raises: + ValueError: + If the provided model does not have weights in the PyTorch `state_dict` format. + + Returns: + v0_5.TorchscriptWeightsDescr: + A descriptor object that contains information about the exported TorchScript weights. + """ + state_dict_weights_descr = model_descr.weights.pytorch_state_dict + if state_dict_weights_descr is None: + raise ValueError( + "The provided model does not have weights in the pytorch state dict format" + ) + + input_data = model_descr.get_input_test_arrays() + + with torch.no_grad(): + input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data] + model = load_torch_model(state_dict_weights_descr) + scripted_module: ScriptModule = ( + torch.jit.trace(model, input_data) + if use_tracing + else torch.jit.script(model) + ) + self._check_predictions( + model=model, + scripted_model=scripted_module, + model_spec=model_descr, + input_data=input_data, + ) + + scripted_module.save(str(output_path)) + + return v0_5.TorchscriptWeightsDescr( + source=output_path, + pytorch_version=Version(torch.__version__), + parent="pytorch_state_dict", + ) + + def _check_predictions( + self, + model: Any, + scripted_model: Any, + model_spec: v0_4.ModelDescr | v0_5.ModelDescr, + input_data: Sequence[torch.Tensor], + ): + assert torch is not None + + def _check(input_: Sequence[torch.Tensor]) -> None: + expected_tensors = model(*input_) + if isinstance(expected_tensors, torch.Tensor): + expected_tensors = [expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [ + out.numpy() for out in expected_tensors + ] + + output_tensors = scripted_model(*input_) + if isinstance(output_tensors, torch.Tensor): + output_tensors = [output_tensors] + outputs: List[np.ndarray[Any, Any]] = [ + out.numpy() for out in output_tensors + ] + + try: + for exp, out in zip(expected_outputs, outputs): + assert_array_almost_equal(exp, out, decimal=4) + except AssertionError as e: + raise ValueError( + f"Results before and after weights conversion do not agree:\n {str(e)}" + ) + + _check(input_data) + + if len(model_spec.inputs) > 1: + return # FIXME: why don't we check multiple inputs? + + input_descr = model_spec.inputs[0] + if isinstance(input_descr, v0_4.InputTensorDescr): + if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape): + return + min_shape = input_descr.shape.min + step = input_descr.shape.step + else: + min_shape: List[int] = [] + step: List[int] = [] + for axis in input_descr.axes: + if isinstance(axis.size, v0_5.ParameterizedSize): + min_shape.append(axis.size.min) + step.append(axis.size.step) + elif isinstance(axis.size, int): + min_shape.append(axis.size) + step.append(0) + elif axis.size is None: + raise NotImplementedError( + f"Can't verify inputs that don't specify their shape fully: {axis}" + ) + elif isinstance(axis.size, v0_5.SizeReference): + raise NotImplementedError(f"Can't handle axes like '{axis}' yet") + else: + assert_never(axis.size) + + input_data = input_data[0] + max_shape = input_data.shape + max_steps = 4 + + # check that input and output agree for decreasing input sizes + for step_factor in range(1, max_steps + 1): + slice_ = tuple( + ( + slice(None) + if step_dim == 0 + else slice(0, max_dim - step_factor * step_dim, 1) + ) + for max_dim, step_dim in zip(max_shape, step) + ) + sliced_input = input_data[slice_] + if any( + sliced_dim < min_dim + for sliced_dim, min_dim in zip(sliced_input.shape, min_shape) + ): + return + _check([sliced_input]) + + +class Tensorflow2Bundled(WeightConverter): + def __init__(self): + super().__init__() + assert tensorflow is not None + + def convert( + self, model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path + ) -> v0_5.TensorflowSavedModelBundleWeightsDescr: + """ + Convert model weights from the 'keras_hdf5' format to the 'tensorflow_saved_model_bundle' format. + + This method handles the conversion of Keras HDF5 model weights into a TensorFlow SavedModel bundle, + which is the recommended format for deploying TensorFlow models. The method supports both TensorFlow 1.x + and 2.x versions, with appropriate checks to ensure compatibility. + + Adapted from: + https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py + + Args: + model_descr (Union[v0_4.ModelDescr, v0_5.ModelDescr]): + The bioimage.io model description containing the model's metadata and weights. + output_path (Path): + The directory where the TensorFlow SavedModel bundle will be saved. + This path must not already exist and, if necessary, will be zipped into a .zip file. + use_tracing (bool): + Placeholder argument; currently not used in this method but required to match the abstract method signature. + + Raises: + ValueError: + - If the specified `output_path` already exists. + - If the Keras HDF5 weights are missing in the model description. + RuntimeError: + If there is a mismatch between the TensorFlow version used by the model and the version installed. + NotImplementedError: + If the model has multiple inputs or outputs and TensorFlow 1.x is being used. + + Returns: + v0_5.TensorflowSavedModelBundleWeightsDescr: + A descriptor object containing information about the converted TensorFlow SavedModel bundle. + """ + assert tensorflow is not None + tf_major_ver = int(tensorflow.__version__.split(".")[0]) + + if output_path.suffix == ".zip": + output_path = output_path.with_suffix("") + zip_weights = True + else: + zip_weights = False + + if output_path.exists(): + raise ValueError(f"The ouptut directory at {output_path} must not exist.") + + if model_descr.weights.keras_hdf5 is None: + raise ValueError("Missing Keras Hdf5 weights to convert from.") + + weight_spec = model_descr.weights.keras_hdf5 + weight_path = download(weight_spec.source).path + + if weight_spec.tensorflow_version: + model_tf_major_ver = int(weight_spec.tensorflow_version.major) + if model_tf_major_ver != tf_major_ver: + raise RuntimeError( + f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}" + ) + + if tf_major_ver == 1: + if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1: + raise NotImplementedError( + "Weight conversion for models with multiple inputs or outputs is not yet implemented." + ) + return self._convert_tf1( + weight_path, + output_path, + model_descr.inputs[0].id, + model_descr.outputs[0].id, + zip_weights, + ) + else: + return self._convert_tf2(weight_path, output_path, zip_weights) + + def _convert_tf2( + self, keras_weight_path: Path, output_path: Path, zip_weights: bool + ) -> v0_5.TensorflowSavedModelBundleWeightsDescr: + try: + # try to build the tf model with the keras import from tensorflow + from tensorflow import keras + except Exception: + # if the above fails try to export with the standalone keras + import keras + + model = keras.models.load_model(keras_weight_path) + keras.models.save_model(model, output_path) + + if zip_weights: + output_path = self._zip_model_bundle(output_path) + print("TensorFlow model exported to", output_path) + + return v0_5.TensorflowSavedModelBundleWeightsDescr( + source=output_path, + parent="keras_hdf5", + tensorflow_version=Version(tensorflow.__version__), + ) + + # adapted from + # https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 + def _convert_tf1( + self, + keras_weight_path: Path, + output_path: Path, + input_name: str, + output_name: str, + zip_weights: bool, + ) -> v0_5.TensorflowSavedModelBundleWeightsDescr: + try: + # try to build the tf model with the keras import from tensorflow + from tensorflow import ( + keras, # type: ignore + ) + + except Exception: + # if the above fails try to export with the standalone keras + import keras + + @no_type_check + def build_tf_model(): + keras_model = keras.models.load_model(keras_weight_path) + assert tensorflow is not None + builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path) + signature = ( + tensorflow.saved_model.signature_def_utils.predict_signature_def( + inputs={input_name: keras_model.input}, + outputs={output_name: keras_model.output}, + ) + ) + + signature_def_map = { + tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature + } + + builder.add_meta_graph_and_variables( + keras.backend.get_session(), + [tensorflow.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map, + ) + builder.save() + + build_tf_model() + + if zip_weights: + output_path = self._zip_model_bundle(output_path) + print("TensorFlow model exported to", output_path) + + return v0_5.TensorflowSavedModelBundleWeightsDescr( + source=output_path, + parent="keras_hdf5", + tensorflow_version=Version(tensorflow.__version__), + ) + + def _zip_model_bundle(self, model_bundle_folder: Path): + zipped_model_bundle = model_bundle_folder.with_suffix(".zip") + + with ZipFile(zipped_model_bundle, "w") as zip_obj: + for root, _, files in os.walk(model_bundle_folder): + for filename in files: + src = os.path.join(root, filename) + zip_obj.write(src, os.path.relpath(src, model_bundle_folder)) + + try: + shutil.rmtree(model_bundle_folder) + except Exception: + print("TensorFlow bundled model was not removed after compression") + + return zipped_model_bundle diff --git a/setup.py b/setup.py index 7aa66e16..a1a86f45 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ extras_require={ "pytorch": ["torch>=1.6", "torchvision", "keras>=3.0"], "tensorflow": ["tensorflow", "keras>=2.15"], - "onnx": ["onnxruntime"], + "onnx": ["onnxruntime", "onnx"], "dev": [ "black", # "crick", # currently requires python<=3.9 diff --git a/tests/weight_converter/test_add_weights.py b/tests/test_add_weights.py similarity index 100% rename from tests/weight_converter/test_add_weights.py rename to tests/test_add_weights.py diff --git a/tests/test_weight_converters.py b/tests/test_weight_converters.py new file mode 100644 index 00000000..88010744 --- /dev/null +++ b/tests/test_weight_converters.py @@ -0,0 +1,69 @@ +# type: ignore # TODO enable type checking +import zipfile +from pathlib import Path + +import pytest + +import os + +from bioimageio.spec import load_description +from bioimageio.spec.model import v0_5 + +from bioimageio.core.weight_converters import ( + Pytorch2Torchscipt, + Pytorch2Onnx, + Tensorflow2Bundled, +) + + +def test_torchscript_converter(any_torch_model, tmp_path): + bio_model = load_description(any_torch_model) + out_path = tmp_path / "weights.pt" + util = Pytorch2Torchscipt() + ret_val = util.convert(bio_model, out_path) + assert out_path.exists() + assert isinstance(ret_val, v0_5.TorchscriptWeightsDescr) + assert ret_val.source == out_path + + +def test_onnx_converter(convert_to_onnx, tmp_path): + bio_model = load_description(convert_to_onnx) + out_path = tmp_path / "weights.onnx" + opset_version = 15 + util = Pytorch2Onnx() + ret_val = util.convert( + model_descr=bio_model, + output_path=out_path, + test_decimal=3, + opset_version=opset_version, + ) + assert os.path.exists(out_path) + assert isinstance(ret_val, v0_5.OnnxWeightsDescr) + assert ret_val.opset_version == opset_version + assert ret_val.source == out_path + + +def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): + model = load_description(any_keras_model) + out_path = tmp_path / "weights.h5" + util = Tensorflow2Bundled() + ret_val = util.convert(model, out_path) + assert out_path.exists() + assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr) + assert ret_val.source == out_path + + +@pytest.mark.skip() +def test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path): + out_path = tmp_path / "weights.zip" + model = load_description(any_keras_model) + util = Tensorflow2Bundled() + ret_val = util.convert(model, out_path) + + assert out_path.exists() + assert isinstance(ret_val, v0_5.TensorflowSavedModelBundleWeightsDescr) + + expected_names = {"saved_model.pb", "variables/variables.index"} + with zipfile.ZipFile(out_path, "r") as f: + names = set([name for name in f.namelist()]) + assert len(expected_names - names) == 0 diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py deleted file mode 100644 index 65c93f60..00000000 --- a/tests/weight_converter/keras/test_tensorflow.py +++ /dev/null @@ -1,52 +0,0 @@ -# type: ignore # TODO enable type checking -import zipfile -from pathlib import Path - -import pytest - -from bioimageio.spec import load_description -from bioimageio.spec.model.v0_5 import ModelDescr - - -@pytest.mark.skip( - "tensorflow converter not updated yet" -) # TODO: test tensorflow converter -def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): - from bioimageio.core.weight_converter.keras import ( - convert_weights_to_tensorflow_saved_model_bundle, - ) - - out_path = tmp_path / "weights" - model = load_description(any_keras_model) - assert isinstance(model, ModelDescr), model.validation_summary.format() - ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path) - assert out_path.exists() - assert (out_path / "variables").exists() - assert (out_path / "saved_model.pb").exists() - assert ( - ret_val == 0 - ) # check for correctness is done in converter and returns 0 if it passes - - -@pytest.mark.skip( - "tensorflow converter not updated yet" -) # TODO: test tensorflow converter -def test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path): - from bioimageio.core.weight_converter.keras import ( - convert_weights_to_tensorflow_saved_model_bundle, - ) - - out_path = tmp_path / "weights.zip" - model = load_description(any_keras_model) - assert isinstance(model, ModelDescr), model.validation_summary.format() - ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path) - assert out_path.exists() - assert ( - ret_val == 0 - ) # check for correctness is done in converter and returns 0 if it passes - - # make sure that the zip package was created correctly - expected_names = {"saved_model.pb", "variables/variables.index"} - with zipfile.ZipFile(out_path, "r") as f: - names = set([name for name in f.namelist()]) - assert len(expected_names - names) == 0 diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py deleted file mode 100644 index 54f2cdf4..00000000 --- a/tests/weight_converter/torch/test_onnx.py +++ /dev/null @@ -1,18 +0,0 @@ -# type: ignore # TODO enable type checking -import os -from pathlib import Path - -import pytest - - -@pytest.mark.skip("onnx converter not updated yet") # TODO: test onnx converter -def test_onnx_converter(convert_to_onnx: Path, tmp_path: Path): - from bioimageio.core.weight_converter.torch._onnx import convert_weights_to_onnx - - out_path = tmp_path / "weights.onnx" - ret_val = convert_weights_to_onnx(convert_to_onnx, out_path, test_decimal=3) - assert os.path.exists(out_path) - if not pytest.skip_onnx: - assert ( - ret_val == 0 - ) # check for correctness is done in converter and returns 0 if it passes diff --git a/tests/weight_converter/torch/test_torchscript.py b/tests/weight_converter/torch/test_torchscript.py deleted file mode 100644 index e0cee3d8..00000000 --- a/tests/weight_converter/torch/test_torchscript.py +++ /dev/null @@ -1,22 +0,0 @@ -# type: ignore # TODO enable type checking -from pathlib import Path - -import pytest - -from bioimageio.spec.model import v0_4, v0_5 - - -@pytest.mark.skip( - "torchscript converter not updated yet" -) # TODO: test torchscript converter -def test_torchscript_converter( - any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path -): - from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript - - out_path = tmp_path / "weights.pt" - ret_val = convert_weights_to_torchscript(any_torch_model, out_path) - assert out_path.exists() - assert ( - ret_val == 0 - ) # check for correctness is done in converter and returns 0 if it passes