Skip to content

Commit 8b3b6b1

Browse files
Merge pull request #247 from bioimage-io/adapt-new-version
Adapt to changes in the version object
2 parents 59a3d21 + 74ef93f commit 8b3b6b1

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22
from typing import List, Optional, Sequence
3+
from marshmallow import missing
34

45
# by default, we use the keras integrated with tensorflow
56
try:
@@ -18,10 +19,11 @@
1819

1920
class KerasModelAdapter(ModelAdapter):
2021
def _load(self, *, devices: Optional[Sequence[str]] = None) -> None:
21-
try:
22-
model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
23-
except AttributeError:
22+
model_tf_version = self.bioimageio_model.weights["keras_hdf5"].tensorflow_version
23+
if model_tf_version is missing:
2424
model_tf_version = None
25+
else:
26+
model_tf_version = (int(model_tf_version.major), int(model_tf_version.minor))
2527

2628
if TF_VERSION is None or model_tf_version is None:
2729
warnings.warn("Could not check tensorflow versions. The prediction results may be wrong.")

bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import numpy as np
66
import tensorflow as tf
77
import xarray as xr
8+
from marshmallow import missing
89

9-
from bioimageio.core.resource_io import nodes
1010
from ._model_adapter import ModelAdapter
1111

1212
try:
@@ -35,10 +35,11 @@ def _load_model(self, weight_file):
3535
return str(weight_file)
3636

3737
def _load(self, *, devices: Optional[List[str]] = None):
38-
try:
39-
model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
40-
except AttributeError:
38+
model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version
39+
if model_tf_version is missing:
4140
model_tf_version = None
41+
else:
42+
model_tf_version = (int(model_tf_version.major), int(model_tf_version.minor))
4243

4344
tf_version = tf.__version__
4445
tf_major_and_minor = tuple(map(int, tf_version.split(".")))[:2]

bioimageio/core/weight_converter/keras/tensorflow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def convert_weights_to_tensorflow_saved_model_bundle(
9393
weight_path = str(weight_spec.source)
9494

9595
if weight_spec.tensorflow_version:
96-
model_tf_major_ver = weight_spec.tensorflow_version.version[0]
96+
model_tf_major_ver = int(weight_spec.tensorflow_version.major)
9797
if model_tf_major_ver != tf_major_ver:
9898
raise RuntimeError(f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}")
9999

0 commit comments

Comments
 (0)