# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Accelerated video decoding on GPUs with CUDA and NVDEC
================================================================
.. _ndecoderec_tutorial:
TorchCodec can use supported Nvidia hardware (see support matrix
`here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_) to speed-up
video decoding. This is called "CUDA Decoding" and it uses Nvidia's
`NVDEC hardware decoder <https://developer.nvidia.com/video-codec-sdk>`_
and CUDA kernels to respectively decompress and convert to RGB.
CUDA Decoding can be faster than CPU Decoding for the actual decoding step and also for
subsequent transform steps like scaling, cropping or rotating. This is because the decode step leaves
the decoded tensor in GPU memory so the GPU doesn't have to fetch from main memory before
running the transform steps. Encoded packets are often much smaller than decoded frames so
CUDA decoding also uses less PCI-e bandwidth.
CUDA Decoding can offer speed-up over CPU Decoding in a few scenarios:
#. You are decoding a large resolution video
#. You are decoding a large batch of videos that's saturating the CPU
#. You want to do whole-image transforms like scaling or convolutions on the decoded tensors
after decoding
#. Your CPU is saturated and you want to free it up for other work
Here are situations where CUDA Decoding may not make sense:
#. You want bit-exact results compared to CPU Decoding
#. You have small resolution videos and the PCI-e transfer latency is large
#. Your GPU is already busy and CPU is not
It's best to experiment with CUDA Decoding to see if it improves your use-case. With
TorchCodec you can simply pass in a device parameter to the
:class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding.
In order to use CUDA Decoding will need the following installed in your environment:
#. An Nvidia GPU that supports decoding the video format you want to decode. See
the support matrix `here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_
#. `CUDA-enabled pytorch <https://pytorch.org/get-started/locally/>`_
#. FFmpeg binaries that support NdecoderEC-enabled codecs
#. libnpp and nvrtc (these are usually installed when you install the full cuda-toolkit)
FFmpeg versions 5, 6 and 7 from conda-forge are built with NdecoderEC support and you can
install them with conda. For example, to install FFmpeg version 7:
.. code-block:: bash
conda install ffmpeg=7 -c conda-forge
conda install libnpp cuda-nvrtc -c nvidia
"""
# %%
# Checking if PyTorch has CUDA enabled
# -------------------------------------
#
# .. note::
#
# This tutorial requires FFmpeg libraries compiled with CUDA support.
#
#
import torch
print(f"{torch.__version__=}")
print(f"{torch.cuda.is_available()=}")
print(f"{torch.cuda.get_device_properties(0)=}")
# %%
# Downloading the video
# -------------------------------------
#
# We will use the following video, which has these properties:
#
# - Codec: H.264
# - Resolution: 960x540
# - FPS: 29.97
# - Pixel format: YUV420P
#
# .. raw:: html
#
# <video style="max-width: 100%" controls>
# <source src="https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4" type="video/mp4">
# </video>
import urllib.request
video_file = "video.mp4"
urllib.request.urlretrieve(
"https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4",
video_file,
)
# %%
# CUDA Decoding using VideoDecoder
# -------------------------------------
#
# To use the CUDA decoder, pass a CUDA device to the decoder:
#
from torchcodec.decoders import VideoDecoder
decoder = VideoDecoder(video_file, device="cuda")
frame = decoder[0]
# %%
#
# The frame is decoded and returned as a uint8 tensor in CHW format.
print(frame.data.shape, frame.data.dtype)
# %%
#
# The decoded frames stay in GPU memory.
print(frame.data.device)
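# %%
# Because the decoded frame already lives in GPU memory, whole-image
# transforms can run on the same device without a round-trip through host
# memory. Below is a minimal sketch of that idea using plain PyTorch ops
# (``interpolate`` is a generic PyTorch function, not a TorchCodec API); it
# uses a synthetic CHW frame so it also runs on CPU-only machines.

```python
import torch
import torch.nn.functional as F

# Synthetic stand-in for a decoded frame: CHW uint8, 540x960.
device = "cuda" if torch.cuda.is_available() else "cpu"
frame = torch.randint(0, 256, (3, 540, 960), dtype=torch.uint8, device=device)

# Resize on the same device the frame lives on. interpolate expects a float
# NCHW batch, so add a batch dimension and convert the dtype first.
resized = F.interpolate(
    frame.unsqueeze(0).float(), size=(270, 480), mode="bilinear", antialias=True
).squeeze(0)
print(resized.shape, resized.device)
```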
# %%
# Visualizing Frames
# -------------------------------------
#
# Let's look at the frames decoded by CUDA decoder and compare them
# against equivalent results from the CPU decoder.
import matplotlib.pyplot as plt
def get_frames(timestamps: list[float], device: str):
decoder = VideoDecoder(video_file, device=device)
return [decoder.get_frame_played_at(ts) for ts in timestamps]
def get_numpy_images(frames):
numpy_images = []
for frame in frames:
# We transfer to the CPU so they can be visualized by matplotlib.
numpy_image = frame.data.to("cpu").permute(1, 2, 0).numpy()
numpy_images.append(numpy_image)
return numpy_images
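# %%
# A quick aside on the ``permute`` above: TorchCodec returns frames in CHW
# layout (channels first), while ``matplotlib.pyplot.imshow`` expects HWC
# (channels last). A small sketch of the layout change with a dummy tensor:

```python
import torch

# CHW (channels first) -> HWC (channels last) for matplotlib.imshow.
chw = torch.randint(0, 256, (3, 540, 960), dtype=torch.uint8)
hwc = chw.permute(1, 2, 0)
print(chw.shape, "->", hwc.shape)  # torch.Size([3, 540, 960]) -> torch.Size([540, 960, 3])
```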
timestamps = [12, 19, 45, 131, 180]
cpu_frames = get_frames(timestamps, device="cpu")
cuda_frames = get_frames(timestamps, device="cuda:0")
cpu_numpy_images = get_numpy_images(cpu_frames)
cuda_numpy_images = get_numpy_images(cuda_frames)
def plot_cpu_and_cuda_images():
n_rows = len(timestamps)
fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
for i in range(n_rows):
axes[i][0].imshow(cpu_numpy_images[i])
axes[i][1].imshow(cuda_numpy_images[i])
axes[0][0].set_title("CPU decoder")
axes[0][1].set_title("CUDA decoder")
plt.setp(axes, xticks=[], yticks=[])
plt.tight_layout()
plot_cpu_and_cuda_images()
# %%
#
# They look identical to the human eye, but there may be subtle
# differences because CUDA math is not bit-exact with respect to CPU math.
#
first_cpu_frame = cpu_frames[0].data.to("cpu")
first_cuda_frame = cuda_frames[0].data.to("cpu")
frames_equal = torch.equal(first_cpu_frame, first_cuda_frame)
print(f"{frames_equal=}")