diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
index 545ddf9c..0bafb718 100644
--- a/.github/workflows/docs.yaml
+++ b/.github/workflows/docs.yaml
@@ -1,9 +1,10 @@
 name: Docs
 on:
-  push:
-    branches: [ main ]
-  pull_request:
+  workflow_run:
+    workflows: [linux_cuda_wheels]
+    types:
+      - completed
 
 defaults:
   run:
@@ -11,7 +12,7 @@ defaults:
 
 jobs:
   build:
-    runs-on: ubuntu-latest
+    runs-on: linux.g5.4xlarge.nvidia.gpu
     strategy:
       fail-fast: false
     steps:
@@ -23,19 +24,24 @@ jobs:
           auto-update-conda: true
           miniconda-version: "latest"
           activate-environment: test
-          python-version: '3.12'
+          python-version: '3.9'
       - name: Update pip
         run: python -m pip install --upgrade pip
-      - name: Install dependencies and FFmpeg
+      - name: Download wheel
+        uses: actions/download-artifact@v3
+        with:
+          name: pytorch_torchcodec__3.9_cu124_x86_64
+          path: pytorch/torchcodec/dist/
+      - name: Install torchcodec from the wheel
         run: |
-          # TODO: torchvision and torchaudio shouldn't be needed. They were only added
-          # to silence an error as seen in https://github.com/pytorch/torchcodec/issues/203
-          python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
-          conda install "ffmpeg=7.0.1" pkg-config -c conda-forge
-          ffmpeg -version
-      - name: Build and install torchcodec
+          wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`
+          echo Installing $wheel_path
+          ${CONDA_RUN} python -m pip install $wheel_path -vvv
+      - name: Install FFmpeg and other deps
         run: |
-          python -m pip install -e ".[dev]" --no-build-isolation -vvv
+          conda install cuda-nvrtc=12.4 libnpp -c nvidia
+          conda install ffmpeg=7 -c conda-forge
+          ffmpeg -version
       - name: Install doc dependencies
         run: |
           cd docs
diff --git a/.github/workflows/linux_cuda_wheel.yaml b/.github/workflows/linux_cuda_wheel.yaml
index 915c5236..17272a24 100644
--- a/.github/workflows/linux_cuda_wheel.yaml
+++ b/.github/workflows/linux_cuda_wheel.yaml
@@ -1,4 +1,4 @@
-name: Build and test Linux CUDA wheels
+name: linux_cuda_wheels
 on:
   pull_request:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 20d6db90..3e8ed8e7 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -50,6 +50,14 @@ We achieve these capabilities through:
 
       How to sample video clips
 
+   .. grid-item-card:: :octicon:`file-code;1em`
+      GPU decoding using TorchCodec
+      :img-top: _static/img/card-background.svg
+      :link: generated_examples/basic_cuda_example.html
+      :link-type: url
+
+      A simple example demonstrating Nvidia GPU decoding
+
 .. toctree::
    :maxdepth: 1
    :caption: TorchCodec documentation
diff --git a/examples/basic_cuda_example.py b/examples/basic_cuda_example.py
new file mode 100644
index 00000000..f1bd103a
--- /dev/null
+++ b/examples/basic_cuda_example.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Accelerated video decoding on GPUs with CUDA and NVDEC
+================================================================
+
+.. _nvdec_tutorial:
+
+TorchCodec can use supported Nvidia hardware (see the support matrix
+`here `_) to speed up
+video decoding. This is called "CUDA Decoding", and it uses Nvidia's
+`NVDEC hardware decoder `_
+and CUDA kernels to decompress the video and convert it to RGB, respectively.
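+
+For a quick preview, decoding on the GPU is a one-line change. The snippet
+below is a condensed sketch of the code that is shown, and fully explained,
+later in this tutorial:
+
+.. code-block:: python
+
+    from torchcodec.decoders import VideoDecoder
+
+    decoder = VideoDecoder("video.mp4", device="cuda")  # decode on the GPU
+    frame = decoder[0]  # the decoded frame stays in GPU memory
+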
+CUDA Decoding can be faster than CPU Decoding for the actual decoding step and also for
+subsequent transform steps like scaling, cropping or rotating. This is because the decode step leaves
+the decoded tensor in GPU memory, so the GPU doesn't have to fetch from main memory before
+running the transform steps. Encoded packets are often much smaller than decoded frames, so
+CUDA Decoding also uses less PCI-e bandwidth.
+
+CUDA Decoding can offer speed-ups over CPU Decoding in a few scenarios:
+
+#. You are decoding a large-resolution video
+#. You are decoding a large batch of videos that's saturating the CPU
+#. You want to do whole-image transforms like scaling or convolutions on the decoded tensors
+   after decoding
+#. Your CPU is saturated and you want to free it up for other work
+
+Here are situations where CUDA Decoding may not make sense:
+
+#. You want bit-exact results compared to CPU Decoding
+#. You have small-resolution videos and the PCI-e transfer latency is large
+#. Your GPU is already busy and your CPU is not
+
+It's best to experiment with CUDA Decoding to see if it improves your use-case. With
+TorchCodec you can simply pass in a device parameter to the
+:class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding.
+
+In order to use CUDA Decoding, you will need the following installed in your environment:
+
+#. An Nvidia GPU that supports decoding the video format you want to decode. See
+   the support matrix `here `_
+#. `CUDA-enabled pytorch `_
+#. FFmpeg binaries that support NVDEC-enabled codecs
+#. libnpp and nvrtc (these are usually installed when you install the full cuda-toolkit)
+
+FFmpeg versions 5, 6 and 7 from conda-forge are built with NVDEC support, and you can
+install them with conda. For example, to install FFmpeg version 7:
+
+.. code-block:: bash
+
+    conda install ffmpeg=7 -c conda-forge
+    conda install libnpp cuda-nvrtc -c nvidia
+
+"""
+
+# %%
+# Checking if PyTorch has CUDA enabled
+# -------------------------------------
+#
+# .. note::
+#
+#    This tutorial requires FFmpeg libraries compiled with CUDA support.
+#
+#
+import torch
+
+print(f"{torch.__version__=}")
+print(f"{torch.cuda.is_available()=}")
+print(f"{torch.cuda.get_device_properties(0)=}")
+
+
+# %%
+# Downloading the video
+# -------------------------------------
+#
+# We will use the following video, which has these properties:
+#
+# - Codec: H.264
+# - Resolution: 960x540
+# - FPS: 29.97
+# - Pixel format: YUV420P
+#
+# .. raw:: html
+#
+#    <video style="max-width: 100%" controls>
+#      <source src="https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4" type="video/mp4">
+#    </video>
+#
+import urllib.request
+
+video_file = "video.mp4"
+urllib.request.urlretrieve(
+    "https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4",
+    video_file,
+)
+
+
+# %%
+# CUDA Decoding using VideoDecoder
+# -------------------------------------
+#
+# To use the CUDA decoder, pass a CUDA device to the decoder.
+#
+from torchcodec.decoders import VideoDecoder
+
+decoder = VideoDecoder(video_file, device="cuda")
+frame = decoder[0]
+
+# %%
+#
+# The video frame is decoded and returned as a tensor in CHW format.
+
+print(frame.data.shape, frame.data.dtype)
+
+# %%
+#
+# The decoded frames are left in GPU memory.
+
+print(frame.data.device)
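+
+
+# %%
+# Measuring the decoding speed
+# -------------------------------------
+#
+# The helper below is a minimal timing sketch, not part of TorchCodec's API:
+# it decodes the first ``num_frames`` frames on a given device and reports the
+# elapsed wall-clock time. Absolute numbers depend on your GPU, video, and
+# FFmpeg build, so treat the output as illustrative only.
+#
+import time
+
+
+def time_decode(device: str, num_frames: int = 50) -> float:
+    decoder = VideoDecoder(video_file, device=device)
+    decoder[0]  # warm-up, so one-time initialization costs aren't timed
+    start = time.perf_counter()
+    for i in range(num_frames):
+        decoder[i]
+    if device.startswith("cuda"):
+        # Wait for all queued GPU work to finish before stopping the clock.
+        torch.cuda.synchronize()
+    return time.perf_counter() - start
+
+
+print(f"CPU:  {time_decode('cpu'):.3f}s")
+print(f"CUDA: {time_decode('cuda'):.3f}s")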
+
+
+# %%
+# Visualizing Frames
+# -------------------------------------
+#
+# Let's look at the frames decoded by the CUDA decoder and compare them
+# against equivalent results from the CPU decoder.
+#
+import matplotlib.pyplot as plt
+
+
+def get_frames(timestamps: list[float], device: str):
+    decoder = VideoDecoder(video_file, device=device)
+    return [decoder.get_frame_played_at(ts) for ts in timestamps]
+
+
+def get_numpy_images(frames):
+    numpy_images = []
+    for frame in frames:
+        # We transfer the frames to the CPU so they can be visualized by matplotlib.
+        numpy_image = frame.data.to("cpu").permute(1, 2, 0).numpy()
+        numpy_images.append(numpy_image)
+    return numpy_images
+
+
+timestamps = [12, 19, 45, 131, 180]
+cpu_frames = get_frames(timestamps, device="cpu")
+cuda_frames = get_frames(timestamps, device="cuda:0")
+cpu_numpy_images = get_numpy_images(cpu_frames)
+cuda_numpy_images = get_numpy_images(cuda_frames)
+
+
+def plot_cpu_and_cuda_images():
+    n_rows = len(timestamps)
+    fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
+    for i in range(n_rows):
+        axes[i][0].imshow(cpu_numpy_images[i])
+        axes[i][1].imshow(cuda_numpy_images[i])
+
+    axes[0][0].set_title("CPU decoder")
+    axes[0][1].set_title("CUDA decoder")
+    plt.setp(axes, xticks=[], yticks=[])
+    plt.tight_layout()
+
+
+plot_cpu_and_cuda_images()
+
+# %%
+#
+# They look visually similar to the human eye, but there may be subtle
+# differences because CUDA math is not bit-exact with respect to CPU math.
+#
+first_cpu_frame = cpu_frames[0].data.to("cpu")
+first_cuda_frame = cuda_frames[0].data.to("cpu")
+frames_equal = torch.equal(first_cpu_frame, first_cuda_frame)
+print(f"{frames_equal=}")
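+
+# %%
+#
+# To put a number on that difference, here is a small follow-up sketch; the
+# exact values depend on your GPU, driver, and FFmpeg build, so treat the
+# printed statistics as illustrative.
+#
+abs_diff = (first_cpu_frame.float() - first_cuda_frame.float()).abs()
+print(f"Mean absolute difference: {abs_diff.mean().item():.4f}")
+print(f"Max absolute difference: {abs_diff.max().item():.1f}")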