Skip to content

Commit 74f3972

Browse files
authored
Updates for Tensorflow 2.9+ & Pytorch 1.13 (#195)
* Support tensorflow>=2.9 * Support pytorch 1.13 * [CI] Test torch 1.13 & tensorflow 2.9/2.10
1 parent 768d66c commit 74f3972

File tree

4 files changed

+23
-31
lines changed

4 files changed

+23
-31
lines changed

Diff for: .github/workflows/ci.yml

+5-4
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,18 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
ml-deps:
13-
- "torch==1.11.0+cpu torchvision==0.12.0+cpu torchdata==0.3.0 tensorflow-cpu==2.7.1"
14-
- "torch==1.12.1+cpu torchvision==0.13.1+cpu torchdata==0.4.1 tensorflow-cpu==2.8.1"
13+
- "torch==1.11.0+cpu torchvision==0.12.0+cpu torchdata==0.3.0 tensorflow-cpu==2.8.1"
14+
- "torch==1.12.1+cpu torchvision==0.13.1+cpu torchdata==0.4.1 tensorflow-cpu==2.9.1"
15+
- "torch==1.13.0+cpu torchvision==0.14.0+cpu torchdata==0.5.0 tensorflow-cpu==2.10.0"
1516

1617
env:
1718
run_coverage: ${{ github.ref == 'refs/heads/master' }}
1819

1920
steps:
20-
- uses: actions/checkout@v2
21+
- uses: actions/checkout@v3
2122

2223
- name: Set up Python ${{ matrix.python-version }}
23-
uses: actions/setup-python@v2
24+
uses: actions/setup-python@v4
2425
with:
2526
python-version: "3.7"
2627

Diff for: tests/models/test_tensorflow_keras_models.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@
1212
import tensorflow as tf
1313

1414
import tiledb
15-
from tiledb.ml.models.tensorflow_keras import (
16-
TensorflowKerasTileDBModel,
17-
tf_keras_is_keras,
18-
)
19-
20-
if tf_keras_is_keras:
21-
from keras import testing_utils
22-
else:
23-
from tensorflow.python.keras import testing_utils
15+
from tiledb.ml.models.tensorflow_keras import TensorflowKerasTileDBModel
2416

17+
try:
18+
from keras.testing_infra.test_utils import (
19+
get_small_functional_mlp,
20+
get_small_sequential_mlp,
21+
)
22+
except ImportError:
23+
from keras.testing_utils import get_small_functional_mlp, get_small_sequential_mlp
2524

2625
# Suppress all Tensorflow messages
2726
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -75,11 +74,7 @@ def test_load_tiledb_error_with_wrong_uri():
7574

7675
api = pytest.mark.parametrize(
7776
"api",
78-
[
79-
testing_utils.get_small_sequential_mlp,
80-
testing_utils.get_small_functional_mlp,
81-
ConfigSubclassModel,
82-
],
77+
[get_small_sequential_mlp, get_small_functional_mlp, ConfigSubclassModel],
8378
)
8479

8580

Diff for: tiledb/ml/models/tensorflow_keras.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from operator import attrgetter
99
from typing import Any, List, Mapping, Optional, Tuple
1010

11+
import keras
1112
import numpy as np
1213
import tensorflow as tf
1314

@@ -16,20 +17,9 @@
1617
from ._base import Meta, TileDBArtifact, Timestamp, current_milli_time, group_create
1718
from ._tensorboard import TensorBoardTileDB
1819

19-
try:
20-
import keras
21-
22-
if keras.Model is not tf.keras.Model:
23-
raise ImportError
24-
tf_keras_is_keras = True
25-
except ImportError:
26-
import tensorflow.python.keras as keras
27-
28-
tf_keras_is_keras = False
29-
3020
SharedObjectLoadingScope = keras.utils.generic_utils.SharedObjectLoadingScope
3121
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
32-
TFOptimizer = keras.optimizer_v1.TFOptimizer
22+
TFOptimizer = keras.optimizers.TFOptimizer
3323
get_json_type = keras.saving.saved_model.json_utils.get_json_type
3424
preprocess_weights_for_loading = keras.saving.hdf5_format.preprocess_weights_for_loading
3525
saving_utils = keras.saving.saving_utils

Diff for: tiledb/ml/readers/_pytorch_collators.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
from ._tensor_schema import TensorSchema
1313
from .types import TensorKind
1414

15+
try:
16+
nested_tensor = torch.nested.nested_tensor
17+
except AttributeError:
18+
nested_tensor = getattr(torch, "nested_tensor", None)
19+
20+
1521
T = TypeVar("T")
1622

1723

@@ -101,7 +107,7 @@ def convert(self, value: np.ndarray) -> torch.Tensor:
101107

102108
def collate(self, batch: Sequence[np.ndarray]) -> torch.Tensor:
103109
if self.to_nested:
104-
return torch.nested_tensor(tuple(map(torch.from_numpy, batch)))
110+
return nested_tensor(list(map(torch.from_numpy, batch)))
105111
else:
106112
return torch.from_numpy(np.stack(batch))
107113

0 commit comments

Comments
 (0)