# Page header captured with this workflow (run title): Try to build CUDA wheels #5

# Triggered by pushes to main, by pushing any tag, and by every pull request.
name: Build Jax Wheels

on:
  push:
    branches:
      - main
    tags:
      - "*"
  # Check all PRs: no branch filter is applied.
  pull_request:

# Workflow-wide environment, currently disabled. Uncomment to export
# SPHERICART_NO_LOCAL_DEPS=1 to every job.
# env:
#   SPHERICART_NO_LOCAL_DEPS: "1"
jobs:
  # Builds one sphericart-jax wheel per (arch, Python, CUDA) matrix entry.
  # The build runs inside a custom manylinux2014 image that is built in this
  # job from scripts/manylinux2014_<arch>, with the matrix versions passed in
  # as Docker build args.
  build-jax-wheels:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        cibw-arch: ["x86_64"]
        python-version: ["3.11", "3.12"]
        cuda-version: ["12.4"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      # Export dot-free and dashed spellings of the matrix versions
      # (e.g. 12.4 -> 124 and 12-4, 3.11 -> 311) so later steps can use
      # them as ${{ env.* }}.
      - name: Github Actions Envs Setup
        run: |
          CUVERSION="${{ matrix.cuda-version }}"
          PYTHONVERSION="${{ matrix.python-version }}"
          CU_VERSION_NO_DOT=${CUVERSION//./}
          echo "CU_VERSION_NO_DOT=${CU_VERSION_NO_DOT}" >> "$GITHUB_ENV"
          CU_VERSION_DASH=${CUVERSION//./-}
          echo "CU_VERSION_DASH=${CU_VERSION_DASH}" >> "$GITHUB_ENV"
          PYTHON_VER_NO_DOT=${PYTHONVERSION//./}
          echo "PYTHON_VER_NO_DOT=${PYTHON_VER_NO_DOT}" >> "$GITHUB_ENV"

      # QEMU lets non-native architectures (e.g. aarch64) be built if they
      # are added to the cibw-arch matrix.
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3
        with:
          platforms: all

      # Build the custom manylinux Docker image. Its tag must match the
      # CIBW_MANYLINUX_*_IMAGE values used by cibuildwheel below.
      - name: Build Manylinux Docker Image
        run: |
          docker build --no-cache \
            -t manylinux2014_"${{ matrix.cibw-arch }}" \
            --build-arg PYTHON_VER="${{ matrix.python-version }}" \
            --build-arg PYTHON_VER_NO_DOT="${{ env.PYTHON_VER_NO_DOT }}" \
            --build-arg CUDA_VER="${{ matrix.cuda-version }}" \
            --build-arg CUDA_VER_NO_DOT="${{ env.CU_VERSION_NO_DOT }}" \
            --build-arg CUDA_VER_DASH="${{ env.CU_VERSION_DASH }}" \
            scripts/manylinux2014_"${{ matrix.cibw-arch }}"

      # Host Python environment that runs cibuildwheel itself.
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "${{ matrix.python-version }}"

      - name: Install cibuildwheel
        run: python -m pip install cibuildwheel==2.22.0

      - name: Build jax wheels
        run: python -m cibuildwheel ./sphericart-jax
        env:
          CUDA_HOME: /usr/local/cuda
          CIBW_BUILD_VERBOSITY: "3"
          # Build only for the matrix interpreter, e.g. cp311-*.
          CIBW_BUILD: "cp${{ env.PYTHON_VER_NO_DOT }}-*"
          # No build isolation: build dependencies are presumably
          # pre-installed in the custom image (TODO confirm).
          CIBW_BUILD_FRONTEND: "pip; args: --no-build-isolation"
          # Step-level env replaces any job-level variable with the same name,
          # so the complete skip list has to live here.
          CIBW_SKIP: "cp36-* cp37-* cp38-* cp39-* cp310-* *-musllinux* *-win32 *-manylinux_i686"
          CIBW_ARCHS: "${{ matrix.cibw-arch }}"
          # Locally built images from the "Build Manylinux Docker Image" step.
          CIBW_MANYLINUX_X86_64_IMAGE: "manylinux2014_x86_64"
          CIBW_MANYLINUX_AARCH64_IMAGE: "manylinux2014_aarch64"
          # Environment variables set inside the build container.
          CIBW_ENVIRONMENT: >-
            CUDA_HOME=/usr/local/cuda
            SPHERICART_ARCH_NATIVE=OFF
          # Keep CUDA libraries out of the wheel, so they must be provided by
          # the system at runtime. auditwheel --exclude matches on SONAME, so
          # each library is listed under its unversioned name and under its
          # CUDA 12 SONAME (libcudart.so.12, libnvToolsExt.so.1, ...).
          CIBW_REPAIR_WHEEL_COMMAND_LINUX: >-
            auditwheel repair
            --exclude libcuda.so
            --exclude libcuda.so.1
            --exclude libcudart.so
            --exclude libcudart.so.12
            --exclude libnvToolsExt.so
            --exclude libnvToolsExt.so.1
            --exclude libnvrtc.so
            --exclude libnvrtc.so.12
            -w {dest_dir} {wheel}

      - uses: actions/upload-artifact@v4
        with:
          name: "sphericart-jax-py-${{ env.PYTHON_VER_NO_DOT }}-jax+cu${{ env.CU_VERSION_NO_DOT }}-${{ matrix.cibw-arch }}"
          path: "./wheelhouse/*.whl"