diff --git a/setup.py b/setup.py index f1652176..2c85ceb6 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ def _build_all_extensions_with_cmake(self): torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch" cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release") enable_cuda = os.environ.get("ENABLE_CUDA", "") + enable_xpu = os.environ.get("ENABLE_XPU", "") python_version = sys.version_info cmake_args = [ f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}", @@ -120,6 +121,7 @@ def _build_all_extensions_with_cmake(self): f"-DCMAKE_BUILD_TYPE={cmake_build_type}", f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", f"-DENABLE_CUDA={enable_cuda}", + f"-DENABLE_XPU={enable_xpu}", ] Path(self.build_temp).mkdir(parents=True, exist_ok=True) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index f0a8568f..17344530 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -9,6 +9,15 @@ find_package(Torch REQUIRED) find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") +if(ENABLE_CUDA) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_CUDA") +endif() +if(ENABLE_XPU) + find_package(PkgConfig REQUIRED) + pkg_check_modules(L0 REQUIRED IMPORTED_TARGET level-zero) + pkg_check_modules(LIBVA REQUIRED IMPORTED_TARGET libva) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_XPU") +endif() function(make_torchcodec_sublibrary library_name @@ -61,12 +70,15 @@ function(make_torchcodec_libraries AVIOContextHolder.cpp FFMPEGCommon.cpp VideoDecoder.cpp + CPUOnlyDevice.cpp ) if(ENABLE_CUDA) list(APPEND decoder_sources CudaDevice.cpp) - else() - list(APPEND decoder_sources CPUOnlyDevice.cpp) + endif() + + if(ENABLE_XPU) + list(APPEND decoder_sources XpuDevice.cpp) endif() set(decoder_library_dependencies @@ -81,6 +93,11 @@ function(make_torchcodec_libraries ) endif() + if(ENABLE_XPU) + list(APPEND decoder_library_dependencies + PkgConfig::L0 PkgConfig::LIBVA) + endif() + make_torchcodec_sublibrary( "${decoder_library_name}" SHARED diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 81746109..9c5c0d29 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -14,6 +14,7 @@ namespace facebook::torchcodec { TORCH_CHECK(false, "Unsupported device: " + device.str()); } +#ifndef ENABLE_CUDA void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions, @@ -40,5 +41,35 @@ std::optional findCudaCodec( [[maybe_unused]] const AVCodecID& codecId) { throwUnsupportedDeviceError(device); } +#endif // ENABLE_CUDA + +#ifndef ENABLE_XPU +void convertAVFrameToFrameOutputOnXpu( + const torch::Device& device, + [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] UniqueAVFrame& avFrame, + [[maybe_unused]] VideoDecoder::FrameOutput& frameOutput, + [[maybe_unused]] std::optional preAllocatedOutputTensor) { + throwUnsupportedDeviceError(device); +} + +void initializeContextOnXpu( + const torch::Device& device, + [[maybe_unused]] AVCodecContext* codecContext) { + throwUnsupportedDeviceError(device); +} + +void releaseContextOnXpu( + const torch::Device& device, + [[maybe_unused]] AVCodecContext* codecContext) { + throwUnsupportedDeviceError(device); 
+} + +std::optional findXpuCodec( + const torch::Device& device, + [[maybe_unused]] const AVCodecID& codecId) { + throwUnsupportedDeviceError(device); +} +#endif // ENABLE_XPU } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 49aea802..6095540f 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -29,6 +29,10 @@ void initializeContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); +void initializeContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext); + void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamOptions& videoStreamOptions, @@ -36,12 +40,27 @@ void convertAVFrameToFrameOutputOnCuda( VideoDecoder::FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); +void convertAVFrameToFrameOutputOnXpu( + const torch::Device& device, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + VideoDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = std::nullopt); + void releaseContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); +void releaseContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext); + std::optional findCudaCodec( const torch::Device& device, const AVCodecID& codecId); +std::optional findXpuCodec( + const torch::Device& device, + const AVCodecID& codecId); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 654e9a2e..f79cd4f2 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -101,6 +101,8 @@ VideoDecoder::~VideoDecoder() { if (device.type() == torch::kCPU) { } else if (device.type() == torch::kCUDA) { releaseContextOnCuda(device, streamInfo.codecContext.get()); + } else if (device.type() == torch::kXPU) { + releaseContextOnXpu(device, streamInfo.codecContext.get()); } else { TORCH_CHECK(false, "Invalid device type: " + device.str()); } @@ -429,10 +431,16 @@ void VideoDecoder::addStream( // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic - if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - findCudaCodec(device, streamInfo.stream->codecpar->codec_id) - .value_or(avCodec)); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (device.type() == torch::kCUDA) { + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + findCudaCodec(device, streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); + } else if (device.type() == torch::kXPU) { + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + findXpuCodec(device, streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); + } } AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); @@ -447,8 +455,12 @@ void VideoDecoder::addStream( streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; // TODO_CODE_QUALITY same as above. 
- if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - initializeContextOnCuda(device, codecContext); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (device.type() == torch::kCUDA) { + initializeContextOnCuda(device, codecContext); + } else if (device.type() == torch::kXPU) { + initializeContextOnXpu(device, codecContext); + } } retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); @@ -476,7 +488,8 @@ void VideoDecoder::addVideoStream( const VideoStreamOptions& videoStreamOptions) { TORCH_CHECK( videoStreamOptions.device.type() == torch::kCPU || - videoStreamOptions.device.type() == torch::kCUDA, + videoStreamOptions.device.type() == torch::kCUDA || + videoStreamOptions.device.type() == torch::kXPU, "Invalid device type: " + videoStreamOptions.device.str()); addStream( @@ -1226,6 +1239,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( avFrame, frameOutput, preAllocatedOutputTensor); + } else if (streamInfo.videoStreamOptions.device.type() == torch::kXPU) { + convertAVFrameToFrameOutputOnXpu( + streamInfo.videoStreamOptions.device, + streamInfo.videoStreamOptions, + avFrame, + frameOutput, + preAllocatedOutputTensor); } else { TORCH_CHECK( false, diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index bd142d70..fb822a9b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -210,10 +210,13 @@ void _add_video_stream( } else if (device.value().rfind("cuda", 0) == 0) { // starts with "cuda" std::string deviceStr(device.value()); videoStreamOptions.device = torch::Device(deviceStr); + } else if (device.value().rfind("xpu", 0) == 0) { // starts with "xpu" + std::string deviceStr(device.value()); + videoStreamOptions.device = torch::Device(deviceStr); } else { throw std::runtime_error( "Invalid device=" + std::string(device.value()) + - ". device must be either cpu or cuda."); + ". device must be either cpu, cuda or xpu."); } } diff --git a/src/torchcodec/decoders/_core/XpuDevice.cpp b/src/torchcodec/decoders/_core/XpuDevice.cpp new file mode 100644 index 00000000..52d59dd5 --- /dev/null +++ b/src/torchcodec/decoders/_core/XpuDevice.cpp @@ -0,0 +1,444 @@ +#include + +#include +#include + +#include +#include + +#include "src/torchcodec/decoders/_core/DeviceInterface.h" +#include "src/torchcodec/decoders/_core/FFMPEGCommon.h" +#include "src/torchcodec/decoders/_core/VideoDecoder.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { +namespace { + +const int MAX_XPU_GPUS = 128; +// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. +// Set to a positive number to have a cache of that size. +const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; +std::vector g_cached_hw_device_ctxs[MAX_XPU_GPUS]; +std::mutex g_cached_hw_device_mutexes[MAX_XPU_GPUS]; + +torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) { + torch::DeviceIndex deviceIndex = device.index(); + deviceIndex = std::max(deviceIndex, 0); + TORCH_CHECK(deviceIndex >= 0, "Device index out of range"); + // For single GPU- machines libtorch returns -1 for the device index. So for + // that case we set the device index to 0. 
+ return deviceIndex; +} + +void addToCacheIfCacheHasCapacity( + const torch::Device& device, + AVCodecContext* codecContext) { + torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); + if (static_cast(deviceIndex) >= MAX_XPU_GPUS) { + return; + } + std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]); + if (MAX_CONTEXTS_PER_GPU_IN_CACHE >= 0 && + g_cached_hw_device_ctxs[deviceIndex].size() >= + MAX_CONTEXTS_PER_GPU_IN_CACHE) { + return; + } + g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx); + codecContext->hw_device_ctx = nullptr; +} + +AVBufferRef* getFromCache(const torch::Device& device) { + torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); + if (static_cast(deviceIndex) >= MAX_XPU_GPUS) { + return nullptr; + } + std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]); + if (g_cached_hw_device_ctxs[deviceIndex].size() > 0) { + AVBufferRef* hw_device_ctx = g_cached_hw_device_ctxs[deviceIndex].back(); + g_cached_hw_device_ctxs[deviceIndex].pop_back(); + return hw_device_ctx; + } + return nullptr; +} + +AVBufferRef* getVaapiContext(const torch::Device& device) { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("vaapi"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find vaapi device"); + torch::DeviceIndex nonNegativeDeviceIndex = + getFFMPEGCompatibleDeviceIndex(device); + + AVBufferRef* hw_device_ctx = getFromCache(device); + if (hw_device_ctx != nullptr) { + return hw_device_ctx; + } + + std::string renderD = "/dev/dri/renderD128"; + + sycl::device syclDevice = c10::xpu::get_raw_device(nonNegativeDeviceIndex); + if (syclDevice.has(sycl::aspect::ext_intel_pci_address)) { + auto BDF = + syclDevice.get_info(); + renderD = "/dev/dri/by-path/pci-" + BDF + "-render"; + } + + int err = + av_hwdevice_ctx_create(&hw_device_ctx, type, renderD.c_str(), nullptr, 0); + if (err < 0) { + TORCH_CHECK( + false, + "Failed to create specified HW device: ", + getFFMPEGErrorStringFromErrorCode(err)); + } + return hw_device_ctx; +} + +void throwErrorIfNonXpuDevice(const torch::Device& device) { + TORCH_CHECK( + device.type() != torch::kCPU, + "Device functions should only be called if the device is not CPU.") + if (device.type() != torch::kXPU) { + throw std::runtime_error("Unsupported device: " + device.str()); + } +} +} // namespace + +void releaseContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext) { + throwErrorIfNonXpuDevice(device); + addToCacheIfCacheHasCapacity(device, codecContext); +} + +void initializeContextOnXpu( + const torch::Device& device, + AVCodecContext* codecContext) { + throwErrorIfNonXpuDevice(device); + // It is important for pytorch itself to create the xpu context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the xpu context. 
+ torch::Tensor dummyTensorForXpuInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); + codecContext->hw_device_ctx = getVaapiContext(device); + return; +} + +struct vaapiSurface { + vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height); + + ~vaapiSurface() { + vaDestroySurfaces(dpy_, &id_, 1); + } + + inline VASurfaceID id() const { + return id_; + } + + torch::Tensor toTensor(const torch::Device& device); + + private: + VADisplay dpy_; + VASurfaceID id_; +}; + +vaapiSurface::vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height) + : dpy_(dpy) { + VASurfaceAttrib attrib{}; + + attrib.type = VASurfaceAttribPixelFormat; + attrib.flags = VA_SURFACE_ATTRIB_SETTABLE; + attrib.value.type = VAGenericValueTypeInteger; + attrib.value.value.i = VA_FOURCC_RGBX; + + VAStatus res = vaCreateSurfaces( + dpy_, VA_RT_FORMAT_RGB32, width, height, &id_, 1, &attrib, 1); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI surface: ", + vaErrorStr(res)); +} + +void deleter(DLManagedTensor* self) { + std::unique_ptr tensor(self); + std::unique_ptr context( + (ze_context_handle_t*)self->manager_ctx); + zeMemFree(*context, self->dl_tensor.data); +} + +torch::Tensor vaapiSurface::toTensor(const torch::Device& device) { + VADRMPRIMESurfaceDescriptor desc{}; + + VAStatus sts = vaExportSurfaceHandle( + dpy_, + id_, + VA_SURFACE_ATTRIB_MEM_TYPE_DRM_PRIME_2, + VA_EXPORT_SURFACE_READ_ONLY, + &desc); + TORCH_CHECK( + sts == VA_STATUS_SUCCESS, + "vaExportSurfaceHandle failed: ", + vaErrorStr(sts)); + + TORCH_CHECK(desc.num_objects == 1, "Expected 1 fd, got ", desc.num_objects); + TORCH_CHECK(desc.num_layers == 1, "Expected 1 layer, got ", desc.num_layers); + TORCH_CHECK( + desc.layers[0].num_planes == 1, + "Expected 1 plane, got ", + desc.num_layers); + + std::unique_ptr ze_context = + std::make_unique(); + ze_device_handle_t ze_device{}; + sycl::queue queue = c10::xpu::getCurrentXPUStream(device.index()); + + queue + .submit([&](sycl::handler& cgh) { + cgh.host_task([&](const sycl::interop_handle& ih) { + *ze_context = + ih.get_native_context(); + ze_device = + ih.get_native_device(); + }); + }) + .wait(); + + ze_external_memory_import_fd_t import_fd_desc{}; + import_fd_desc.stype = ZE_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMPORT_FD; + import_fd_desc.flags = ZE_EXTERNAL_MEMORY_TYPE_FLAG_DMA_BUF; + import_fd_desc.fd = desc.objects[0].fd; + + ze_device_mem_alloc_desc_t alloc_desc{}; + alloc_desc.pNext = &import_fd_desc; + void* usm_ptr = nullptr; + + ze_result_t res = zeMemAllocDevice( + *ze_context, &alloc_desc, desc.objects[0].size, 0, ze_device, &usm_ptr); + TORCH_CHECK( + res == ZE_RESULT_SUCCESS, "Failed to import fd=", desc.objects[0].fd); + + close(desc.objects[0].fd); + + std::unique_ptr dl_dst = std::make_unique(); + int64_t shape[3] = {desc.height, desc.width, 4}; + + dl_dst->manager_ctx = ze_context.release(); + dl_dst->deleter = deleter; + dl_dst->dl_tensor.data = usm_ptr; + dl_dst->dl_tensor.device.device_type = kDLOneAPI; + dl_dst->dl_tensor.device.device_id = device.index(); + dl_dst->dl_tensor.ndim = 3; + dl_dst->dl_tensor.dtype.code = kDLUInt; + dl_dst->dl_tensor.dtype.bits = 8; + dl_dst->dl_tensor.dtype.lanes = 1; + dl_dst->dl_tensor.shape = shape; + dl_dst->dl_tensor.strides = nullptr; + dl_dst->dl_tensor.byte_offset = desc.layers[0].offset[0]; + + auto dst = at::fromDLPack(dl_dst.release()); + + return dst; +} + +VADisplay getVaDisplayFromAV(UniqueAVFrame& avFrame) { + AVHWFramesContext* hwfc = 
(AVHWFramesContext*)avFrame->hw_frames_ctx->data; + AVHWDeviceContext* hwdc = hwfc->device_ctx; + AVVAAPIDeviceContext* vactx = (AVVAAPIDeviceContext*)hwdc->hwctx; + return vactx->display; +} + +struct vaapiVpContext { + VADisplay dpy_; + VAConfigID config_id_ = VA_INVALID_ID; + VAContextID context_id_ = VA_INVALID_ID; + VABufferID pipeline_buf_id_ = VA_INVALID_ID; + + // These structures must stay alive for the whole lifetime of + // the struct since they are reused by the media driver + // internally during vaRenderPicture(). + VAProcPipelineParameterBuffer pipeline_{}; + VARectangle surface_region_{}; + + vaapiVpContext() = delete; + vaapiVpContext( + VADisplay dpy, + UniqueAVFrame& avFrame, + uint16_t width, + uint16_t height); + + ~vaapiVpContext() { + if (pipeline_buf_id_ != VA_INVALID_ID) + vaDestroyBuffer(dpy_, pipeline_buf_id_); + if (context_id_ != VA_INVALID_ID) + vaDestroyContext(dpy_, context_id_); + if (config_id_ != VA_INVALID_ID) + vaDestroyConfig(dpy_, config_id_); + } + + void convertTo(VASurfaceID id); +}; + +vaapiVpContext::vaapiVpContext( + VADisplay dpy, + UniqueAVFrame& avFrame, + uint16_t width, + uint16_t height) + : dpy_(dpy) { + VAStatus res = vaCreateConfig( + dpy_, VAProfileNone, VAEntrypointVideoProc, nullptr, 0, &config_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI config: ", + vaErrorStr(res)); + + res = vaCreateContext( + dpy_, + config_id_, + width, + height, + VA_PROGRESSIVE, + nullptr, + 0, + &context_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI VP context: ", + vaErrorStr(res)); + + surface_region_.width = width; + surface_region_.height = height; + + pipeline_.surface = (VASurfaceID)(uintptr_t)avFrame->data[3]; + pipeline_.surface_region = &surface_region_; + pipeline_.output_region = &surface_region_; + if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) + pipeline_.surface_color_standard = VAProcColorStandardBT709; + + res = vaCreateBuffer( + dpy_, + context_id_, + VAProcPipelineParameterBufferType, + sizeof(pipeline_), + 1, + &pipeline_, + &pipeline_buf_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaCreateBuffer failed: ", vaErrorStr(res)); +} + +void vaapiVpContext::convertTo(VASurfaceID id) { + VAStatus res = vaBeginPicture(dpy_, context_id_, id); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaBeginPicture failed: ", vaErrorStr(res)); + + res = vaRenderPicture(dpy_, context_id_, &pipeline_buf_id_, 1); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaRenderPicture failed: ", vaErrorStr(res)); + + res = vaEndPicture(dpy_, context_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaEndPicture failed: ", vaErrorStr(res)); + + res = vaSyncSurface(dpy_, id); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaSyncSurface failed: ", vaErrorStr(res)); +} + +torch::Tensor convertAVFrameToTensor( + const torch::Device& device, + UniqueAVFrame& avFrame, + int width, + int height) { + TORCH_CHECK(height > 0, "height must be > 0, got: ", height); + TORCH_CHECK(width > 0, "width must be > 0, got: ", width); + + // Allocate an intermediate tensor that we can convert the input into with VAAPI. + // This tensor should be WxHx4 since VAAPI does not support RGB24 + // and works only with RGB32. + VADisplay va_dpy = getVaDisplayFromAV(avFrame); + // Create the destination VAAPI surface.
+ vaapiSurface va_surface(va_dpy, width, height); + + vaapiVpContext va_vp(va_dpy, avFrame, width, height); + va_vp.convertTo(va_surface.id()); + + return va_surface.toTensor(device); +} + +void convertAVFrameToFrameOutputOnXpu( + const torch::Device& device, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + VideoDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + TORCH_CHECK( + avFrame->format == AV_PIX_FMT_VAAPI, + "Expected format to be AV_PIX_FMT_VAAPI, got " + + std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format))); + auto frameDims = + getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); + int height = frameDims.height; + int width = frameDims.width; + torch::Tensor& dst = frameOutput.data; + if (preAllocatedOutputTensor.has_value()) { + dst = preAllocatedOutputTensor.value(); + auto shape = dst.sizes(); + TORCH_CHECK( + (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) && + (shape[2] == 3), + "Expected tensor of shape ", + height, + "x", + width, + "x3, got ", + shape); + } else { + dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device); + } + + auto start = std::chrono::high_resolution_clock::now(); + + // We convert the input to the RGBX color format with VAAPI, getting a WxHx4 + // tensor as the output. + torch::Tensor dst_rgb4 = + convertAVFrameToTensor(device, avFrame, width, height); + dst.copy_(dst_rgb4.narrow(2, 0, 3)); + + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end - start; + VLOG(9) << "VAAPI conversion of frame height=" << height << " width=" << width + << " took: " << duration.count() << "us" << std::endl; +} + +// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 +// we have to do this because of an FFmpeg bug where hardware decoding is not +// appropriately set, so we just go off and find the matching codec for the VAAPI +// device +std::optional findXpuCodec( + const torch::Device& device, + const AVCodecID& codecId) { + throwErrorIfNonXpuDevice(device); + + void* i = nullptr; + const AVCodec* codec = nullptr; + while ((codec = av_codec_iterate(&i)) != nullptr) { + if (codec->id != codecId || !av_codec_is_decoder(codec)) { + continue; + } + + const AVCodecHWConfig* config = nullptr; + for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; + ++j) { + if (config->device_type == AV_HWDEVICE_TYPE_VAAPI) { + return codec; + } + } + } + + return std::nullopt; +} + +} // namespace facebook::torchcodec diff --git a/test/conftest.py b/test/conftest.py index f5db4b5d..df7b234c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,6 +10,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "needs_cuda: mark for tests that rely on a CUDA device" ) + config.addinivalue_line( + "markers", "needs_xpu: mark for tests that rely on an XPU device" + ) def pytest_collection_modifyitems(items): @@ -19,15 +22,16 @@ def pytest_collection_modifyitems(items): out_items = [] for item in items: - # The needs_cuda mark will exist if the test was explicitly decorated - # with the @needs_cuda decorator. It will also exist if it was + # The needs_[cuda|xpu] mark will exist if the test was explicitly decorated + # with the respective @needs_* decorator.
It will also exist if it was # parametrized with a parameter that has the mark: for example if a test # is parametrized with - # @pytest.mark.parametrize('device', cpu_and_cuda()) + # @pytest.mark.parametrize('device', cpu_and_accelerators()) # the "instances" of the tests where device == 'cuda' will have the # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the # mark. needs_cuda = item.get_closest_marker("needs_cuda") is not None + needs_xpu = item.get_closest_marker("needs_xpu") is not None if ( needs_cuda @@ -42,6 +46,13 @@ def pytest_collection_modifyitems(items): # those for whatever reason, we need to know. item.add_marker(pytest.mark.skip(reason="CUDA not available.")) + if ( + needs_xpu + and not torch.xpu.is_available() + and os.environ.get("FAIL_WITHOUT_XPU") is None + ): + item.add_marker(pytest.mark.skip(reason="XPU not available.")) + out_items.append(item) items[:] = out_items @@ -56,6 +67,8 @@ def prevent_leaking_rng(): builtin_rng_state = random.getstate() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() + if torch.xpu.is_available(): + xpu_rng_state = torch.xpu.get_rng_state() yield @@ -63,3 +76,5 @@ def prevent_leaking_rng(): random.setstate(builtin_rng_state) if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state) + if torch.xpu.is_available(): + torch.xpu.set_rng_state(xpu_rng_state) diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index cc47e116..05a3ebc0 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -18,7 +18,7 @@ from ..utils import ( assert_frames_equal, AV1_VIDEO, - cpu_and_cuda, + cpu_and_accelerators, get_ffmpeg_major_version, H265_VIDEO, in_fbcode, @@ -93,7 +93,7 @@ def test_create_fails(self): VideoDecoder(NASA_VIDEO.path, seek_mode="blah") @pytest.mark.parametrize("num_ffmpeg_threads", (1, 4)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode): decoder = VideoDecoder( @@ -143,7 +143,7 @@ def test_getitem_numpy_int(self): assert_frames_equal(ref_frame1, decoder[numpy.uint32(1)]) assert_frames_equal(ref_frame180, decoder[numpy.uint32(180)]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_slice(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -285,7 +285,7 @@ def test_getitem_slice(self, device, seek_mode): # See https://github.com/pytorch/torchcodec/issues/428 assert_frames_equal(sliced, ref) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -302,7 +302,7 @@ def test_getitem_fails(self, device, seek_mode): with pytest.raises(TypeError, match="Unsupported key type"): frame = decoder[2.3] # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_iteration(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -348,7 
+348,7 @@ def test_iteration_slow(self): assert iterations == len(decoder) == 390 - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -381,7 +381,7 @@ def test_get_frame_at(self, device, seek_mode): frame9 = decoder.get_frame_at(numpy.uint32(9)) assert_frames_equal(ref_frame9, frame9.data) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_tuple_unpacking(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) @@ -392,7 +392,7 @@ def test_get_frame_at_tuple_unpacking(self, device): assert frame.pts_seconds == pts assert frame.duration_seconds == duration - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -403,7 +403,7 @@ def test_get_frame_at_fails(self, device, seek_mode): with pytest.raises(IndexError, match="out of bounds"): frame = decoder.get_frame_at(10000) # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -443,7 +443,7 @@ def test_get_frames_at(self, device, seek_mode): frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0 ) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -457,7 +457,7 @@ def test_get_frames_at_fails(self, device, seek_mode): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_at([0.3]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_av1(self, device): if device == "cuda" and get_ffmpeg_major_version() == 4: return @@ -470,7 +470,7 @@ def test_get_frame_at_av1(self, device): assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_played_at(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -499,7 +499,7 @@ def test_get_frame_played_at_h265(self): ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_played_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -510,7 
+510,7 @@ def test_get_frame_played_at_fails(self, device, seek_mode): with pytest.raises(IndexError, match="Invalid pts in seconds"): frame = decoder.get_frame_played_at(100.0) # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_played_at(self, device, seek_mode): @@ -549,7 +549,7 @@ def test_get_frames_played_at(self, device, seek_mode): frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0 ) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_played_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -563,7 +563,7 @@ def test_get_frames_played_at_fails(self, device, seek_mode): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_played_at(["bad"]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("stream_index", [0, 3, None]) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_in_range(self, stream_index, device, seek_mode): @@ -681,7 +681,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): lambda decoder: decoder.get_frames_played_in_range(0, 1).data, ), ) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode): decoder = VideoDecoder( @@ -709,7 +709,7 @@ def test_dimension_order_fails(self): VideoDecoder(NASA_VIDEO.path, dimension_order="NCDHW") @pytest.mark.parametrize("stream_index", [0, 3, None]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( @@ -848,7 +848,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): ) assert_frames_equal(all_frames.data, decoder[:]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) @@ -862,7 +862,7 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): with pytest.raises(ValueError, match="Invalid stop seconds"): frame = decoder.get_frames_played_in_range(0, 23) # noqa - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_key_frame_indices(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact") key_frame_indices = decoder._get_key_frame_indices() @@ -907,7 +907,7 @@ def test_get_key_frame_indices(self, device): # TODO investigate why this fails internally. 
@pytest.mark.skipif(in_fbcode(), reason="Compile test fails internally.") - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_compile(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 9efb33f3..233c2555 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -40,11 +40,12 @@ from ..utils import ( assert_frames_equal, - cpu_and_cuda, + cpu_and_accelerators, NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, needs_cuda, + needs_xpu, SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, @@ -56,7 +57,7 @@ class TestVideoOps: - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_seek_and_next(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -73,7 +74,7 @@ def test_seek_and_next(self, device): ) assert_frames_equal(frame_time6, reference_frame_time6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_seek_to_negative_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -85,7 +86,7 @@ def test_seek_to_negative_pts(self, device): frame0, _, _ = get_next_frame(decoder) assert_frames_equal(frame0, reference_frame0.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -109,7 +110,7 @@ def test_get_frame_at_pts(self, device): with pytest.raises(AssertionError): assert_frames_equal(next_frame, reference_frame6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -123,7 +124,7 @@ def test_get_frame_at_index(self, device): ) assert_frames_equal(frame6, reference_frame6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frame_with_info_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -135,7 +136,7 @@ def test_get_frame_with_info_at_index(self, device): assert pts.item() == pytest.approx(6.006, rel=1e-3) assert duration.item() == pytest.approx(0.03337, rel=1e-3) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_at_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -147,7 +148,7 @@ def test_get_frames_at_indices(self, device): assert_frames_equal(frames0and180[0], reference_frame0.to(device)) assert_frames_equal(frames0and180[1], reference_frame180.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_at_indices_unsorted_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) @@ -174,7 +175,7 @@ def test_get_frames_at_indices_unsorted_indices(self, device): with pytest.raises(AssertionError): 
assert_frames_equal(frames[0], frames[-1]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_by_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) @@ -202,7 +203,7 @@ def test_get_frames_by_pts(self, device): with pytest.raises(AssertionError): assert_frames_equal(frames[0], frames[-1]) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_pts_apis_against_index_ref(self, device): # Non-regression test for https://github.com/pytorch/torchcodec/pull/287 # Get all frames in the video, then query all frames with all time-based @@ -257,7 +258,7 @@ def test_pts_apis_against_index_ref(self, device): ) torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_get_frames_in_range(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -297,7 +298,7 @@ def test_get_frames_in_range(self, device): empty_frame, *_ = get_frames_in_range(decoder, start=5, stop=5) assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_throws_exception_at_eof(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -312,7 +313,7 @@ def test_throws_exception_at_eof(self, device): with pytest.raises(IndexError, match="no more frames"): get_frame_at_pts(decoder, seconds=1000.0) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_throws_exception_if_seek_too_far(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) @@ -321,7 +322,7 @@ def test_throws_exception_if_seek_too_far(self, device): with pytest.raises(IndexError, match="no more frames"): get_next_frame(decoder) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) def test_compile_seek_and_next(self, device): # TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. 
Right now # compilation fails because it can't handle tensors of size unknown at @@ -345,7 +346,7 @@ def get_frame1_and_frame_time6(decoder): assert_frames_equal(frame0, reference_frame0.to(device)) assert_frames_equal(frame_time6, reference_frame_time6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_accelerators()) @pytest.mark.parametrize( "create_from", ("file", "tensor", "bytes", "file_like_rawio", "file_like_bufferedio"), @@ -631,6 +632,19 @@ def test_cuda_decoder(self): duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 ) + @needs_xpu + def test_xpu_decoder(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, device="xpu") + frame0, pts, duration = get_next_frame(decoder) + assert frame0.device.type == "xpu" + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + assert_frames_equal(frame0, reference_frame0.to("xpu")) + assert pts == torch.tensor([0]) + torch.testing.assert_close( + duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 + ) + class TestAudioOps: @pytest.mark.parametrize( diff --git a/test/utils.py b/test/utils.py index 70f32bfb..dd379ce7 100644 --- a/test/utils.py +++ b/test/utils.py @@ -22,8 +22,19 @@ def needs_cuda(test_item): return pytest.mark.needs_cuda(test_item) -def cpu_and_cuda(): - return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) +# Decorator for skipping XPU tests when XPU isn't available. The tests are +# effectively marked to be skipped in pytest_collection_modifyitems() of +# conftest.py +def needs_xpu(test_item): + return pytest.mark.needs_xpu(test_item) + + +def cpu_and_accelerators(): + return ( + "cpu", + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("xpu", marks=pytest.mark.needs_xpu), + ) def get_ffmpeg_major_version(): @@ -45,6 +56,13 @@ def assert_frames_equal(*args, **kwargs): ) else: torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0) + elif args[0].device.type == "xpu": + if not torch.allclose(*args, atol=0, rtol=0): + from torcheval.metrics import PeakSignalNoiseRatio + + metric = PeakSignalNoiseRatio() + metric.update(args[0], args[1]) + assert metric.compute() >= 40 else: torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) else:
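Taken together, the changes above mean the XPU backend is opted into at build time (setup.py forwards ENABLE_XPU to CMake as -DENABLE_XPU, which pulls in level-zero and libva via pkg-config and compiles XpuDevice.cpp) and selected at run time through the usual device string. The sketch below shows how the new path is expected to be exercised end to end; it assumes a Linux host with an Intel GPU, a VAAPI-capable FFmpeg, and an XPU-enabled PyTorch build, and the install command and video path are illustrative, not part of this diff.

# Hypothetical local build with the new backend enabled; setup.py passes
# ENABLE_XPU through to CMake as -DENABLE_XPU:
#   ENABLE_XPU=1 pip install -e .
import torch
from torchcodec.decoders import VideoDecoder

# Mirrors the needs_xpu skip logic added to conftest.py.
assert torch.xpu.is_available()

# "xpu" (or "xpu:0") is now accepted alongside "cpu" and "cuda" by the
# device validation in VideoDecoderOps.cpp / VideoDecoder.cpp.
decoder = VideoDecoder("my_video.mp4", device="xpu")  # illustrative path
frame = decoder.get_frame_at(0)
print(frame.data.device)                       # expected: an xpu device
print(frame.pts_seconds, frame.duration_seconds)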
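On the test side, the new needs_xpu marker and the renamed cpu_and_accelerators() helper follow the same pattern as the existing CUDA machinery. A minimal, made-up test illustrating the intended usage (assumed to live under test/decoders/, hence the relative import) might look like:

import pytest
import torch

from ..utils import cpu_and_accelerators, needs_xpu


@pytest.mark.parametrize("device", cpu_and_accelerators())
def test_device_roundtrip(device):
    # The "cuda" and "xpu" instances carry the needs_cuda/needs_xpu marks, so
    # conftest.py skips them when the accelerator is missing, unless
    # FAIL_WITHOUT_XPU (or the CUDA equivalent) is set in the environment.
    tensor = torch.zeros(3, device=device)
    assert tensor.device.type == device


@needs_xpu
def test_xpu_only():
    # Explicitly XPU-only; skipped when torch.xpu.is_available() is False.
    assert torch.xpu.is_available()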