From 981d94ffcbf3417b277634f59cc2c5cb914256f7 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Wed, 26 Feb 2025 01:56:45 +0000 Subject: [PATCH] Enable Intel GPU support in torchcodec on Linux (xpu device) This commit enables support for Intel GPUs in torchcodec. It adds: * ffmpeg-vaapi for decoding * VAAPI based color space conversion (decoding output to RGBA) * RGBA surface import as torch tensor (on torch xpu device) * RGBA to RGB24 tensor slicing To build torchcodec with Intel GPU support: * Install pytorch with XPU backend support. For example, with: ``` pip3 install torch --index-url https://download.pytorch.org/whl/xpu ``` * Install oneAPI development environment following https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support * Build and install FFmpeg with `--enable-vaapi` * Install torcheval (for tests): `pip3 install torcheval` * Build torchcodec with: `ENABLE_XPU=1 python3 setup.py devel` Notes: * RGB24 is not supported color format on current Intel GPUs (as it is considered to be suboptimal due to odd alignments) * Intel media and compute APIs can't seamlessly work with the memory from each other. For example, Intel computes's Unified Shared Memory pointers are not recognized by media APIs. Thus, lower level sharing via dma fds is needed. This alos makes this part of the solution OS dependent. * Color space conversion algoriths might be quite different as it happens for Intel. This requires to check PSNR values instead of per-pixel atol/rtol differences. * Installing oneAPI environment is neded due to https://github.com/pytorch/pytorch/issues/149075 This commit was primary verfied on Intel Battlemage G21 (0xe20b) and Intel Data Center GPU Flex (0x56c0). Signed-off-by: Dmitry Rogozhkin --- setup.py | 2 + src/torchcodec/decoders/_core/CMakeLists.txt | 21 +- .../decoders/_core/CPUOnlyDevice.cpp | 31 ++ .../decoders/_core/DeviceInterface.h | 19 + .../decoders/_core/VideoDecoder.cpp | 34 +- .../decoders/_core/VideoDecoderOps.cpp | 5 +- src/torchcodec/decoders/_core/XpuDevice.cpp | 444 ++++++++++++++++++ test/conftest.py | 21 +- test/decoders/test_decoders.py | 42 +- test/decoders/test_ops.py | 44 +- test/utils.py | 22 +- 11 files changed, 634 insertions(+), 51 deletions(-) create mode 100644 src/torchcodec/decoders/_core/XpuDevice.cpp diff --git a/setup.py b/setup.py index f16521764..2c85ceb66 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ def _build_all_extensions_with_cmake(self): torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch" cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release") enable_cuda = os.environ.get("ENABLE_CUDA", "") + enable_xpu = os.environ.get("ENABLE_XPU", "") python_version = sys.version_info cmake_args = [ f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}", @@ -120,6 +121,7 @@ def _build_all_extensions_with_cmake(self): f"-DCMAKE_BUILD_TYPE={cmake_build_type}", f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", f"-DENABLE_CUDA={enable_cuda}", + f"-DENABLE_XPU={enable_xpu}", ] Path(self.build_temp).mkdir(parents=True, exist_ok=True) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index f0a8568fe..173445308 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -9,6 +9,15 @@ find_package(Torch REQUIRED) find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") +if(ENABLE_CUDA) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_CUDA") +endif() +if(ENABLE_XPU) + find_package(PkgConfig REQUIRED) + pkg_check_modules(L0 REQUIRED IMPORTED_TARGET level-zero) + pkg_check_modules(LIBVA REQUIRED IMPORTED_TARGET libva) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_XPU") +endif() function(make_torchcodec_sublibrary library_name @@ -61,12 +70,15 @@ function(make_torchcodec_libraries AVIOContextHolder.cpp FFMPEGCommon.cpp VideoDecoder.cpp + CPUOnlyDevice.cpp ) if(ENABLE_CUDA) list(APPEND decoder_sources CudaDevice.cpp) - else() - list(APPEND decoder_sources CPUOnlyDevice.cpp) + endif() + + if(ENABLE_XPU) + list(APPEND decoder_sources XpuDevice.cpp) endif() set(decoder_library_dependencies @@ -81,6 +93,11 @@ function(make_torchcodec_libraries ) endif() + if(ENABLE_XPU) + list(APPEND decoder_library_dependencies + PkgConfig::L0 PkgConfig::LIBVA) + endif() + make_torchcodec_sublibrary( "${decoder_library_name}" SHARED diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 817461099..9c5c0d297 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -14,6 +14,7 @@ namespace facebook::torchcodec { TORCH_CHECK(false, "Unsupported device: " + device.str()); } +#ifndef ENABLE_CUDA void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions, @@ -40,5 +41,35 @@ std::optional findCudaCodec( [[maybe_unused]] const AVCodecID& codecId) { throwUnsupportedDeviceError(device); } +#endif // ENABLE_CUDA + +#ifndef ENABLE_XPU +void convertAVFrameToFrameOutputOnXpu( + const torch::Device& device, + [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] UniqueAVFrame& avFrame, + [[maybe_unused]] VideoDecoder::FrameOutput& frameOutput, + [[maybe_unused]] std::optional preAllocatedOutputTensor) { + throwUnsupportedDeviceError(device); +} + +void initializeContextOnXpu( + const torch::Device& device, + [[maybe_unused]] AVCodecContext* codecContext) { + throwUnsupportedDeviceError(device); +} + +void releaseContextOnXpu( + const torch::Device& device, + [[maybe_unused]] AVCodecContext* codecContext) { + throwUnsupportedDeviceError(device); +} + +std::optional findXpuCodec( + const torch::Device& device, + [[maybe_unused]] const AVCodecID& codecId) { + throwUnsupportedDeviceError(device); +} +#endif // ENABLE_XPU } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 49aea8024..6095540f4 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -29,6 +29,10 @@ void initializeContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); +void initializeContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext); + void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamOptions& videoStreamOptions, @@ -36,12 +40,27 @@ void convertAVFrameToFrameOutputOnCuda( VideoDecoder::FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); +void convertAVFrameToFrameOutputOnXpu( + const torch::Device& device, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + VideoDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = std::nullopt); + void releaseContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); +void releaseContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext); + std::optional findCudaCodec( const torch::Device& device, const AVCodecID& codecId); +std::optional findXpuCodec( + const torch::Device& device, + const AVCodecID& codecId); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 654e9a2e8..f79cd4f24 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -101,6 +101,8 @@ VideoDecoder::~VideoDecoder() { if (device.type() == torch::kCPU) { } else if (device.type() == torch::kCUDA) { releaseContextOnCuda(device, streamInfo.codecContext.get()); + } else if (device.type() == torch::kXPU) { + releaseContextOnXpu(device, streamInfo.codecContext.get()); } else { TORCH_CHECK(false, "Invalid device type: " + device.str()); } @@ -429,10 +431,16 @@ void VideoDecoder::addStream( // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic - if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - findCudaCodec(device, streamInfo.stream->codecpar->codec_id) - .value_or(avCodec)); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (device.type() == torch::kCUDA) { + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + findCudaCodec(device, streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); + } else if (device.type() == torch::kXPU) { + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + findXpuCodec(device, streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); + } } AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); @@ -447,8 +455,12 @@ void VideoDecoder::addStream( streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; // TODO_CODE_QUALITY same as above. - if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - initializeContextOnCuda(device, codecContext); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (device.type() == torch::kCUDA) { + initializeContextOnCuda(device, codecContext); + } else if (device.type() == torch::kXPU) { + initializeContextOnXpu(device, codecContext); + } } retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); @@ -476,7 +488,8 @@ void VideoDecoder::addVideoStream( const VideoStreamOptions& videoStreamOptions) { TORCH_CHECK( videoStreamOptions.device.type() == torch::kCPU || - videoStreamOptions.device.type() == torch::kCUDA, + videoStreamOptions.device.type() == torch::kCUDA || + videoStreamOptions.device.type() == torch::kXPU, "Invalid device type: " + videoStreamOptions.device.str()); addStream( @@ -1226,6 +1239,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( avFrame, frameOutput, preAllocatedOutputTensor); + } else if (streamInfo.videoStreamOptions.device.type() == torch::kXPU) { + convertAVFrameToFrameOutputOnXpu( + streamInfo.videoStreamOptions.device, + streamInfo.videoStreamOptions, + avFrame, + frameOutput, + preAllocatedOutputTensor); } else { TORCH_CHECK( false, diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index bd142d70e..fb822a9bc 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -210,10 +210,13 @@ void _add_video_stream( } else if (device.value().rfind("cuda", 0) == 0) { // starts with "cuda" std::string deviceStr(device.value()); videoStreamOptions.device = torch::Device(deviceStr); + } else if (device.value().rfind("xpu", 0) == 0) { // starts with "xpu" + std::string deviceStr(device.value()); + videoStreamOptions.device = torch::Device(deviceStr); } else { throw std::runtime_error( "Invalid device=" + std::string(device.value()) + - ". device must be either cpu or cuda."); + ". device must be either cpu, cuda or xpu."); } } diff --git a/src/torchcodec/decoders/_core/XpuDevice.cpp b/src/torchcodec/decoders/_core/XpuDevice.cpp new file mode 100644 index 000000000..52d59dd56 --- /dev/null +++ b/src/torchcodec/decoders/_core/XpuDevice.cpp @@ -0,0 +1,444 @@ +#include + +#include +#include + +#include +#include + +#include "src/torchcodec/decoders/_core/DeviceInterface.h" +#include "src/torchcodec/decoders/_core/FFMPEGCommon.h" +#include "src/torchcodec/decoders/_core/VideoDecoder.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { +namespace { + +const int MAX_XPU_GPUS = 128; +// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. +// Set to a positive number to have a cache of that size. +const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; +std::vector g_cached_hw_device_ctxs[MAX_XPU_GPUS]; +std::mutex g_cached_hw_device_mutexes[MAX_XPU_GPUS]; + +torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) { + torch::DeviceIndex deviceIndex = device.index(); + deviceIndex = std::max(deviceIndex, 0); + TORCH_CHECK(deviceIndex >= 0, "Device index out of range"); + // For single GPU- machines libtorch returns -1 for the device index. So for + // that case we set the device index to 0. + return deviceIndex; +} + +void addToCacheIfCacheHasCapacity( + const torch::Device& device, + AVCodecContext* codecContext) { + torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); + if (static_cast(deviceIndex) >= MAX_XPU_GPUS) { + return; + } + std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]); + if (MAX_CONTEXTS_PER_GPU_IN_CACHE >= 0 && + g_cached_hw_device_ctxs[deviceIndex].size() >= + MAX_CONTEXTS_PER_GPU_IN_CACHE) { + return; + } + g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx); + codecContext->hw_device_ctx = nullptr; +} + +AVBufferRef* getFromCache(const torch::Device& device) { + torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); + if (static_cast(deviceIndex) >= MAX_XPU_GPUS) { + return nullptr; + } + std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]); + if (g_cached_hw_device_ctxs[deviceIndex].size() > 0) { + AVBufferRef* hw_device_ctx = g_cached_hw_device_ctxs[deviceIndex].back(); + g_cached_hw_device_ctxs[deviceIndex].pop_back(); + return hw_device_ctx; + } + return nullptr; +} + +AVBufferRef* getVaapiContext(const torch::Device& device) { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("vaapi"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find vaapi device"); + torch::DeviceIndex nonNegativeDeviceIndex = + getFFMPEGCompatibleDeviceIndex(device); + + AVBufferRef* hw_device_ctx = getFromCache(device); + if (hw_device_ctx != nullptr) { + return hw_device_ctx; + } + + std::string renderD = "/dev/dri/renderD128"; + + sycl::device syclDevice = c10::xpu::get_raw_device(nonNegativeDeviceIndex); + if (syclDevice.has(sycl::aspect::ext_intel_pci_address)) { + auto BDF = + syclDevice.get_info(); + renderD = "/dev/dri/by-path/pci-" + BDF + "-render"; + } + + int err = + av_hwdevice_ctx_create(&hw_device_ctx, type, renderD.c_str(), nullptr, 0); + if (err < 0) { + TORCH_CHECK( + false, + "Failed to create specified HW device: ", + getFFMPEGErrorStringFromErrorCode(err)); + } + return hw_device_ctx; +} + +void throwErrorIfNonXpuDevice(const torch::Device& device) { + TORCH_CHECK( + device.type() != torch::kCPU, + "Device functions should only be called if the device is not CPU.") + if (device.type() != torch::kXPU) { + throw std::runtime_error("Unsupported device: " + device.str()); + } +} +} // namespace + +void releaseContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext) { + throwErrorIfNonXpuDevice(device); + addToCacheIfCacheHasCapacity(device, codecContext); +} + +void initializeContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext) { + throwErrorIfNonXpuDevice(device); + // It is important for pytorch itself to create the xpu context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the xpu context. + torch::Tensor dummyTensorForXpuInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); + codecContext->hw_device_ctx = getVaapiContext(device); + return; +} + +struct vaapiSurface { + vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height); + + ~vaapiSurface() { + vaDestroySurfaces(dpy_, &id_, 1); + } + + inline VASurfaceID id() const { + return id_; + } + + torch::Tensor toTensor(const torch::Device& device); + + private: + VADisplay dpy_; + VASurfaceID id_; +}; + +vaapiSurface::vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height) + : dpy_(dpy) { + VASurfaceAttrib attrib{}; + + attrib.type = VASurfaceAttribPixelFormat; + attrib.flags = VA_SURFACE_ATTRIB_SETTABLE; + attrib.value.type = VAGenericValueTypeInteger; + attrib.value.value.i = VA_FOURCC_RGBX; + + VAStatus res = vaCreateSurfaces( + dpy_, VA_RT_FORMAT_RGB32, width, height, &id_, 1, &attrib, 1); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI surface: ", + vaErrorStr(res)); +} + +void deleter(DLManagedTensor* self) { + std::unique_ptr tensor(self); + std::unique_ptr context( + (ze_context_handle_t*)self->manager_ctx); + zeMemFree(*context, self->dl_tensor.data); +} + +torch::Tensor vaapiSurface::toTensor(const torch::Device& device) { + VADRMPRIMESurfaceDescriptor desc{}; + + VAStatus sts = vaExportSurfaceHandle( + dpy_, + id_, + VA_SURFACE_ATTRIB_MEM_TYPE_DRM_PRIME_2, + VA_EXPORT_SURFACE_READ_ONLY, + &desc); + TORCH_CHECK( + sts == VA_STATUS_SUCCESS, + "vaExportSurfaceHandle failed: ", + vaErrorStr(sts)); + + TORCH_CHECK(desc.num_objects == 1, "Expected 1 fd, got ", desc.num_objects); + TORCH_CHECK(desc.num_layers == 1, "Expected 1 layer, got ", desc.num_layers); + TORCH_CHECK( + desc.layers[0].num_planes == 1, + "Expected 1 plane, got ", + desc.num_layers); + + std::unique_ptr ze_context = + std::make_unique(); + ze_device_handle_t ze_device{}; + sycl::queue queue = c10::xpu::getCurrentXPUStream(device.index()); + + queue + .submit([&](sycl::handler& cgh) { + cgh.host_task([&](const sycl::interop_handle& ih) { + *ze_context = + ih.get_native_context(); + ze_device = + ih.get_native_device(); + }); + }) + .wait(); + + ze_external_memory_import_fd_t import_fd_desc{}; + import_fd_desc.stype = ZE_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMPORT_FD; + import_fd_desc.flags = ZE_EXTERNAL_MEMORY_TYPE_FLAG_DMA_BUF; + import_fd_desc.fd = desc.objects[0].fd; + + ze_device_mem_alloc_desc_t alloc_desc{}; + alloc_desc.pNext = &import_fd_desc; + void* usm_ptr = nullptr; + + ze_result_t res = zeMemAllocDevice( + *ze_context, &alloc_desc, desc.objects[0].size, 0, ze_device, &usm_ptr); + TORCH_CHECK( + res == ZE_RESULT_SUCCESS, "Failed to import fd=", desc.objects[0].fd); + + close(desc.objects[0].fd); + + std::unique_ptr dl_dst = std::make_unique(); + int64_t shape[3] = {desc.height, desc.width, 4}; + + dl_dst->manager_ctx = ze_context.release(); + dl_dst->deleter = deleter; + dl_dst->dl_tensor.data = usm_ptr; + dl_dst->dl_tensor.device.device_type = kDLOneAPI; + dl_dst->dl_tensor.device.device_id = device.index(); + dl_dst->dl_tensor.ndim = 3; + dl_dst->dl_tensor.dtype.code = kDLUInt; + dl_dst->dl_tensor.dtype.bits = 8; + dl_dst->dl_tensor.dtype.lanes = 1; + dl_dst->dl_tensor.shape = shape; + dl_dst->dl_tensor.strides = nullptr; + dl_dst->dl_tensor.byte_offset = desc.layers[0].offset[0]; + + auto dst = at::fromDLPack(dl_dst.release()); + + return dst; +} + +VADisplay getVaDisplayFromAV(UniqueAVFrame& avFrame) { + AVHWFramesContext* hwfc = (AVHWFramesContext*)avFrame->hw_frames_ctx->data; + AVHWDeviceContext* hwdc = hwfc->device_ctx; + AVVAAPIDeviceContext* vactx = (AVVAAPIDeviceContext*)hwdc->hwctx; + return vactx->display; +} + +struct vaapiVpContext { + VADisplay dpy_; + VAConfigID config_id_ = VA_INVALID_ID; + VAContextID context_id_ = VA_INVALID_ID; + VABufferID pipeline_buf_id_ = VA_INVALID_ID; + + // These structures must be available thru all life + // circle of the struct since they are reused by the media + // driver internally during vaRenderPicture(). + VAProcPipelineParameterBuffer pipeline_{}; + VARectangle surface_region_{}; + + vaapiVpContext() = delete; + vaapiVpContext( + VADisplay dpy, + UniqueAVFrame& avFrame, + uint16_t width, + uint16_t height); + + ~vaapiVpContext() { + if (pipeline_buf_id_ != VA_INVALID_ID) + vaDestroyBuffer(dpy_, pipeline_buf_id_); + if (context_id_ != VA_INVALID_ID) + vaDestroyContext(dpy_, context_id_); + if (config_id_ != VA_INVALID_ID) + vaDestroyConfig(dpy_, config_id_); + } + + void convertTo(VASurfaceID id); +}; + +vaapiVpContext::vaapiVpContext( + VADisplay dpy, + UniqueAVFrame& avFrame, + uint16_t width, + uint16_t height) + : dpy_(dpy) { + VAStatus res = vaCreateConfig( + dpy_, VAProfileNone, VAEntrypointVideoProc, nullptr, 0, &config_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI config: ", + vaErrorStr(res)); + + res = vaCreateContext( + dpy_, + config_id_, + width, + height, + VA_PROGRESSIVE, + nullptr, + 0, + &context_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI VP context: ", + vaErrorStr(res)); + + surface_region_.width = width; + surface_region_.height = height; + + pipeline_.surface = (VASurfaceID)(uintptr_t)avFrame->data[3]; + pipeline_.surface_region = &surface_region_; + pipeline_.output_region = &surface_region_; + if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) + pipeline_.surface_color_standard = VAProcColorStandardBT709; + + res = vaCreateBuffer( + dpy_, + context_id_, + VAProcPipelineParameterBufferType, + sizeof(pipeline_), + 1, + &pipeline_, + &pipeline_buf_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaCreateBuffer failed: ", vaErrorStr(res)); +} + +void vaapiVpContext::convertTo(VASurfaceID id) { + VAStatus res = vaBeginPicture(dpy_, context_id_, id); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaBeginPicture failed: ", vaErrorStr(res)); + + res = vaRenderPicture(dpy_, context_id_, &pipeline_buf_id_, 1); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaRenderPicture failed: ", vaErrorStr(res)); + + res = vaEndPicture(dpy_, context_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaEndPicture failed: ", vaErrorStr(res)); + + res = vaSyncSurface(dpy_, id); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaSyncSurface failed: ", vaErrorStr(res)); +} + +torch::Tensor convertAVFrameToTensor( + const torch::Device& device, + UniqueAVFrame& avFrame, + int width, + int height) { + TORCH_CHECK(height > 0, "height must be > 0, got: ", height); + TORCH_CHECK(width > 0, "width must be > 0, got: ", width); + + // Allocating intermediate tensor we can convert input to with VAAPI. + // This tensor should be WxHx4 since VAAPI does not support RGB24 + // and works only with RGB32. + VADisplay va_dpy = getVaDisplayFromAV(avFrame); + // Importing tensor to VAAPI. + vaapiSurface va_surface(va_dpy, width, height); + + vaapiVpContext va_vp(va_dpy, avFrame, width, height); + va_vp.convertTo(va_surface.id()); + + return va_surface.toTensor(device); +} + +void convertAVFrameToFrameOutputOnXpu( + const torch::Device& device, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + VideoDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + TORCH_CHECK( + avFrame->format == AV_PIX_FMT_VAAPI, + "Expected format to be AV_PIX_FMT_VAAPI, got " + + std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format))); + auto frameDims = + getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); + int height = frameDims.height; + int width = frameDims.width; + torch::Tensor& dst = frameOutput.data; + if (preAllocatedOutputTensor.has_value()) { + dst = preAllocatedOutputTensor.value(); + auto shape = dst.sizes(); + TORCH_CHECK( + (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) && + (shape[2] == 3), + "Expected tensor of shape ", + height, + "x", + width, + "x3, got ", + shape); + } else { + dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device); + } + + auto start = std::chrono::high_resolution_clock::now(); + + // We convert input to the RGBX color format with VAAPI getting WxHx4 + // tensor on the output. + torch::Tensor dst_rgb4 = + convertAVFrameToTensor(device, avFrame, width, height); + dst.copy_(dst_rgb4.narrow(2, 0, 3)); + + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end - start; + VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width + << " took: " << duration.count() << "us" << std::endl; +} + +// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 +// we have to do this because of an FFmpeg bug where hardware decoding is not +// appropriately set, so we just go off and find the matching codec for the CUDA +// device +std::optional findXpuCodec( + const torch::Device& device, + const AVCodecID& codecId) { + throwErrorIfNonXpuDevice(device); + + void* i = nullptr; + const AVCodec* codec = nullptr; + while ((codec = av_codec_iterate(&i)) != nullptr) { + if (codec->id != codecId || !av_codec_is_decoder(codec)) { + continue; + } + + const AVCodecHWConfig* config = nullptr; + for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; + ++j) { + if (config->device_type == AV_HWDEVICE_TYPE_VAAPI) { + return codec; + } + } + } + + return std::nullopt; +} + +} // namespace facebook::torchcodec diff --git a/test/conftest.py b/test/conftest.py index f5db4b5d6..df7b234cc 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,6 +10,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "needs_cuda: mark for tests that rely on a CUDA device" ) + config.addinivalue_line( + "markers", "needs_xpu: mark for tests that rely on a XPU device" + ) def pytest_collection_modifyitems(items): @@ -19,15 +22,16 @@ def pytest_collection_modifyitems(items): out_items = [] for item in items: - # The needs_cuda mark will exist if the test was explicitly decorated - # with the @needs_cuda decorator. It will also exist if it was + # The needs_[cuda|xpu] mark will exist if the test was explicitly decorated + # with the respective @needs_* decorator. It will also exist if it was # parametrized with a parameter that has the mark: for example if a test # is parametrized with - # @pytest.mark.parametrize('device', cpu_and_cuda()) + # @pytest.mark.parametrize('device', cpu_and_accelerators()) # the "instances" of the tests where device == 'cuda' will have the # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the # mark. needs_cuda = item.get_closest_marker("needs_cuda") is not None + needs_xpu = item.get_closest_marker("needs_xpu") is not None if ( needs_cuda @@ -42,6 +46,13 @@ def pytest_collection_modifyitems(items): # those for whatever reason, we need to know. item.add_marker(pytest.mark.skip(reason="CUDA not available.")) + if ( + needs_xpu + and not torch.xpu.is_available() + and os.environ.get("FAIL_WITHOUT_XPU") is None + ): + item.add_marker(pytest.mark.skip(reason="XPU not available.")) + out_items.append(item) items[:] = out_items @@ -56,6 +67,8 @@ def prevent_leaking_rng(): builtin_rng_state = random.getstate() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() + if torch.xpu.is_available(): + xpu_rng_state = torch.xpu.get_rng_state() yield @@ -63,3 +76,5 @@ def prevent_leaking_rng(): random.setstate(builtin_rng_state) if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state) + if torch.xpu.is_available(): + torch.xpu.set_rng_state(xpu_rng_state) diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index cc47e116d..05a3ebc0f 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -18,7 +18,7 @@ from ..utils import ( assert_frames_equal, AV1_VIDEO, - cpu_and_cuda, + cpu_and_accelerators, get_ffmpeg_major_version, H265_VIDEO, in_fbcode, @@ -93,7 +93,7 @@ def test_create_fails(self): VideoDecoder(NASA_VIDEO.path, seek_mode="blah") @pytest.mark.parametrize("num_ffmpeg_threads", (1, 4)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode): decoder = VideoDecoder( @@ -143,7 +143,7 @@ def test_getitem_numpy_int(self): assert_frames_equal(ref_frame1, decoder[numpy.uint32(1)]) assert_frames_equal(ref_frame180, decoder[numpy.uint32(180)]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_slice(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -285,7 +285,7 @@ def test_getitem_slice(self, device, seek_mode): # See https://github.com/pytorch/torchcodec/issues/428 assert_frames_equal(sliced, ref) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -302,7 +302,7 @@ def test_getitem_fails(self, device, seek_mode): with pytest.raises(TypeError, match="Unsupported key type"): frame = decoder[2.3] # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_iteration(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -348,7 +348,7 @@ def test_iteration_slow(self): assert iterations == len(decoder) == 390 - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -381,7 +381,7 @@ def test_get_frame_at(self, device, seek_mode): frame9 = decoder.get_frame_at(numpy.uint32(9)) assert_frames_equal(ref_frame9, frame9.data) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_tuple_unpacking(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) @@ -392,7 +392,7 @@ def test_get_frame_at_tuple_unpacking(self, device): assert frame.pts_seconds == pts assert frame.duration_seconds == duration - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -403,7 +403,7 @@ def test_get_frame_at_fails(self, device, seek_mode): with pytest.raises(IndexError, match="out of bounds"): frame = decoder.get_frame_at(10000) # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -443,7 +443,7 @@ def test_get_frames_at(self, device, seek_mode): frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0 ) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -457,7 +457,7 @@ def test_get_frames_at_fails(self, device, seek_mode): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_at([0.3]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_av1(self, device): if device == "cuda" and get_ffmpeg_major_version() == 4: return @@ -470,7 +470,7 @@ def test_get_frame_at_av1(self, device): assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_played_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -499,7 +499,7 @@ def test_get_frame_played_at_h265(self): ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_played_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -510,7 +510,7 @@ def test_get_frame_played_at_fails(self, device, seek_mode): with pytest.raises(IndexError, match="Invalid pts in seconds"): frame = decoder.get_frame_played_at(100.0) # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_played_at(self, device, seek_mode): @@ -549,7 +549,7 @@ def test_get_frames_played_at(self, device, seek_mode): frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0 ) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_played_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -563,7 +563,7 @@ def test_get_frames_played_at_fails(self, device, seek_mode): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_played_at(["bad"]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("stream_index", [0, 3, None]) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_in_range(self, stream_index, device, seek_mode): @@ -681,7 +681,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): lambda decoder: decoder.get_frames_played_in_range(0, 1).data, ), ) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode): decoder = VideoDecoder( @@ -709,7 +709,7 @@ def test_dimension_order_fails(self): VideoDecoder(NASA_VIDEO.path, dimension_order="NCDHW") @pytest.mark.parametrize("stream_index", [0, 3, None]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( @@ -848,7 +848,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): ) assert_frames_equal(all_frames.data, decoder[:]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -862,7 +862,7 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): with pytest.raises(ValueError, match="Invalid stop seconds"): frame = decoder.get_frames_played_in_range(0, 23) # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_key_frame_indices(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact") key_frame_indices = decoder._get_key_frame_indices() @@ -907,7 +907,7 @@ def test_get_key_frame_indices(self, device): # TODO investigate why this fails internally. @pytest.mark.skipif(in_fbcode(), reason="Compile test fails internally.") - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_compile(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 9efb33f35..233c2555f 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -40,11 +40,12 @@ from ..utils import ( assert_frames_equal, - cpu_and_cuda, + cpu_and_accelerators, NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, needs_cuda, + needs_xpu, SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, @@ -56,7 +57,7 @@ class TestVideoOps: - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_seek_and_next(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -73,7 +74,7 @@ def test_seek_and_next(self, device): ) assert_frames_equal(frame_time6, reference_frame_time6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_seek_to_negative_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -85,7 +86,7 @@ def test_seek_to_negative_pts(self, device): frame0, _, _ = get_next_frame(decoder) assert_frames_equal(frame0, reference_frame0.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -109,7 +110,7 @@ def test_get_frame_at_pts(self, device): with pytest.raises(AssertionError): assert_frames_equal(next_frame, reference_frame6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -123,7 +124,7 @@ def test_get_frame_at_index(self, device): ) assert_frames_equal(frame6, reference_frame6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_with_info_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -135,7 +136,7 @@ def test_get_frame_with_info_at_index(self, device): assert pts.item() == pytest.approx(6.006, rel=1e-3) assert duration.item() == pytest.approx(0.03337, rel=1e-3) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_at_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -147,7 +148,7 @@ def test_get_frames_at_indices(self, device): assert_frames_equal(frames0and180[0], reference_frame0.to(device)) assert_frames_equal(frames0and180[1], reference_frame180.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_at_indices_unsorted_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) @@ -174,7 +175,7 @@ def test_get_frames_at_indices_unsorted_indices(self, device): with pytest.raises(AssertionError): assert_frames_equal(frames[0], frames[-1]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_by_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) @@ -202,7 +203,7 @@ def test_get_frames_by_pts(self, device): with pytest.raises(AssertionError): assert_frames_equal(frames[0], frames[-1]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_pts_apis_against_index_ref(self, device): # Non-regression test for https://github.com/pytorch/torchcodec/pull/287 # Get all frames in the video, then query all frames with all time-based @@ -257,7 +258,7 @@ def test_pts_apis_against_index_ref(self, device): ) torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_in_range(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -297,7 +298,7 @@ def test_get_frames_in_range(self, device): empty_frame, *_ = get_frames_in_range(decoder, start=5, stop=5) assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_throws_exception_at_eof(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -312,7 +313,7 @@ def test_throws_exception_at_eof(self, device): with pytest.raises(IndexError, match="no more frames"): get_frame_at_pts(decoder, seconds=1000.0) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_throws_exception_if_seek_too_far(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -321,7 +322,7 @@ def test_throws_exception_if_seek_too_far(self, device): with pytest.raises(IndexError, match="no more frames"): get_next_frame(decoder) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_compile_seek_and_next(self, device): # TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. Right now # compilation fails because it can't handle tensors of size unknown at @@ -345,7 +346,7 @@ def get_frame1_and_frame_time6(decoder): assert_frames_equal(frame0, reference_frame0.to(device)) assert_frames_equal(frame_time6, reference_frame_time6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize( "create_from", ("file", "tensor", "bytes", "file_like_rawio", "file_like_bufferedio"), @@ -631,6 +632,19 @@ def test_cuda_decoder(self): duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 ) + @needs_xpu + def test_xpu_decoder(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, device="xpu") + frame0, pts, duration = get_next_frame(decoder) + assert frame0.device.type == "xpu" + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + assert_frames_equal(frame0, reference_frame0.to("xpu")) + assert pts == torch.tensor([0]) + torch.testing.assert_close( + duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 + ) + class TestAudioOps: @pytest.mark.parametrize( diff --git a/test/utils.py b/test/utils.py index 70f32bfbe..dd379ce7d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -22,8 +22,19 @@ def needs_cuda(test_item): return pytest.mark.needs_cuda(test_item) -def cpu_and_cuda(): - return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) +# Decorator for skipping XPU tests when XPU isn't available. The tests are +# effectively marked to be skipped in pytest_collection_modifyitems() of +# conftest.py +def needs_xpu(test_item): + return pytest.mark.needs_xpu(test_item) + + +def cpu_and_accelerators(): + return ( + "cpu", + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("xpu", marks=pytest.mark.needs_xpu), + ) def get_ffmpeg_major_version(): @@ -45,6 +56,13 @@ def assert_frames_equal(*args, **kwargs): ) else: torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0) + elif args[0].device.type == "xpu": + if not torch.allclose(*args, atol=0, rtol=0): + from torcheval.metrics import PeakSignalNoiseRatio + + metric = PeakSignalNoiseRatio() + metric.update(args[0], args[1]) + assert metric.compute() >= 40 else: torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) else: