diff --git a/images/tensorflow-notebook/Dockerfile b/images/tensorflow-notebook/Dockerfile index 974e0d6e41..2dcc18fce6 100644 --- a/images/tensorflow-notebook/Dockerfile +++ b/images/tensorflow-notebook/Dockerfile @@ -17,10 +17,16 @@ RUN mamba install --yes \ fix-permissions "${CONDA_DIR}" && \ fix-permissions "/home/${NB_USER}" -# Install tensorflow with pip, on x86_64 tensorflow-cpu -RUN [[ $(uname -m) = x86_64 ]] && TF_POSTFIX="-cpu" || TF_POSTFIX="" && \ - pip install --no-cache-dir \ - "tensorflow${TF_POSTFIX}" && \ +# Can't use `conda-forge` for aarch64 tensorflow: +# https://github.com/conda-forge/tensorflow-feedstock/issues/136 +RUN if [ "$(uname -m)" = "x86_64" ]; then \ + mamba install --yes \ + 'tensorflow-cpu' && \ + mamba clean --all -f -y ; \ + else \ + pip install --no-cache-dir \ + 'tensorflow'; \ + fi && \ fix-permissions "${CONDA_DIR}" && \ fix-permissions "/home/${NB_USER}" diff --git a/images/tensorflow-notebook/cuda/Dockerfile b/images/tensorflow-notebook/cuda/Dockerfile index a872e25cd6..2b6eea8ea4 100644 --- a/images/tensorflow-notebook/cuda/Dockerfile +++ b/images/tensorflow-notebook/cuda/Dockerfile @@ -17,9 +17,12 @@ RUN mamba install --yes \ fix-permissions "${CONDA_DIR}" && \ fix-permissions "/home/${NB_USER}" -# Install TensorFlow, CUDA and cuDNN with pip -RUN pip install --no-cache-dir \ - "tensorflow[and-cuda]<=2.17.1" && \ +# Install TensorFlow, CUDA and cuDNN +# Specifying `CONDA_OVERRIDE_CUDA` to use gpu version on a machine with cpu only: +# https://github.com/conda-forge/tensorflow-feedstock/issues/174 +RUN CONDA_OVERRIDE_CUDA=12.6 mamba install --yes \ + 'tensorflow-gpu' && \ + mamba clean --all -f -y && \ fix-permissions "${CONDA_DIR}" && \ fix-permissions "/home/${NB_USER}" diff --git a/tests/by_image/docker-stacks-foundation/test_packages.py b/tests/by_image/docker-stacks-foundation/test_packages.py index d8f3ce8add..73743d086f 100644 --- a/tests/by_image/docker-stacks-foundation/test_packages.py +++ b/tests/by_image/docker-stacks-foundation/test_packages.py @@ -39,6 +39,8 @@ "pytables": "tables", "scikit-image": "skimage", "scikit-learn": "sklearn", + "tensorflow-cpu": "tensorflow", + "tensorflow-gpu": "tensorflow", # R "randomforest": "randomForest", "rcurl": "RCurl",