diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index 8086d0b4b..90cd2cc98 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -196,10 +196,48 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  // We check that avFrame->format == AV_PIX_FMT_CUDA. This only ensures the
+  // AVFrame is on GPU memory. It can be on CPU memory if the video isn't
+  // supported by NVDEC for whatever reason: NVDEC falls back to CPU decoding in
+  // this case, and our check fails.
+  // TODO: we could send the frame back into the CPU path, and rely on
+  // swscale/filtergraph to run the color conversion to properly output the
+  // frame.
   TORCH_CHECK(
       avFrame->format == AV_PIX_FMT_CUDA,
-      "Expected format to be AV_PIX_FMT_CUDA, got " +
-          std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
+      "Expected format to be AV_PIX_FMT_CUDA, got ",
+      (av_get_pix_fmt_name((AVPixelFormat)avFrame->format)
+           ? av_get_pix_fmt_name((AVPixelFormat)avFrame->format)
+           : "unknown"),
+      ". When that happens, it is probably because the video is not supported by NVDEC. "
+      "Try using the CPU device instead. "
+      "If the video is 10bit, we are tracking 10bit support in "
+      "https://github.com/pytorch/torchcodec/issues/776");
+
+  // Above we checked that the AVFrame was on GPU, but that's not enough: we
+  // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
+  // because this is what the NPP color conversion routines expect.
+  // TODO: we should investigate how we can perform color conversion for
+  // non-8bit videos. This is supported on CPU.
+  TORCH_CHECK(
+      avFrame->hw_frames_ctx != nullptr,
+      "The AVFrame does not have a hw_frames_ctx. "
+      "That's unexpected, please report this to the TorchCodec repo.");
+
+  AVPixelFormat actualFormat =
+      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data)
+          ->sw_format;
+  TORCH_CHECK(
+      actualFormat == AV_PIX_FMT_NV12,
+      "The AVFrame is ",
+      (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
+                                         : "unknown"),
+      ", but we expected AV_PIX_FMT_NV12. This typically happens when "
+      "the video isn't 8bit, which is not supported on CUDA at the moment. "
+      "Try using the CPU device instead. "
" + "If the video is 10bit, we are tracking 10bit support in " + "https://github.com/pytorch/torchcodec/issues/776"); + auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); int height = frameDims.height; diff --git a/test/resources/h264_10bits.mp4 b/test/resources/h264_10bits.mp4 new file mode 100644 index 000000000..804362a35 Binary files /dev/null and b/test/resources/h264_10bits.mp4 differ diff --git a/test/resources/h265_10bits.mp4 b/test/resources/h265_10bits.mp4 new file mode 100644 index 000000000..25e35ef6a Binary files /dev/null and b/test/resources/h265_10bits.mp4 differ diff --git a/test/test_decoders.py b/test/test_decoders.py index dcf9a1585..5b104e060 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -26,12 +26,15 @@ AV1_VIDEO, cpu_and_cuda, get_ffmpeg_major_version, + H264_10BITS, + H265_10BITS, H265_VIDEO, in_fbcode, NASA_AUDIO, NASA_AUDIO_MP3, NASA_AUDIO_MP3_44100, NASA_VIDEO, + needs_cuda, SINE_MONO_S16, SINE_MONO_S32, SINE_MONO_S32_44100, @@ -1138,6 +1141,31 @@ def test_pts_to_dts_fallback(self, seek_mode): with pytest.raises(AssertionError, match="not equal"): torch.testing.assert_close(decoder[0], decoder[10]) + @needs_cuda + @pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS)) + def test_10bit_videos_cuda(self, asset): + # Assert that we raise proper error on different kinds of 10bit videos. + + # TODO we should investigate how to support 10bit videos on GPU. + # See https://github.com/pytorch/torchcodec/issues/776 + + decoder = VideoDecoder(asset.path, device="cuda") + + if asset is H265_10BITS: + match = "The AVFrame is p010le, but we expected AV_PIX_FMT_NV12." + else: + match = "Expected format to be AV_PIX_FMT_CUDA, got yuv420p10le." + with pytest.raises(RuntimeError, match=match): + decoder.get_frame_at(0) + + @pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS)) + def test_10bit_videos_cpu(self, asset): + # This just validates that we can decode 10-bit videos on CPU. + # TODO validate against the ref that the decoded frames are correct + + decoder = VideoDecoder(asset.path) + decoder.get_frame_at(10) + class TestAudioDecoder: @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32)) diff --git a/test/utils.py b/test/utils.py index e3368e3f4..58dd8b5da 100644 --- a/test/utils.py +++ b/test/utils.py @@ -321,6 +321,28 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: frames={}, # Automatically loaded from json file ) +# Video generated with: +# ffmpeg -f lavfi -i testsrc2=duration=1:size=200x200:rate=30 -c:v libx265 -pix_fmt yuv420p10le -preset fast -crf 23 h265_10bits.mp4 +H265_10BITS = TestVideo( + filename="h265_10bits.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=200, height=200, num_color_channels=3), + }, + frames={0: {}}, # Not needed yet +) + +# Video generated with: +# peg -f lavfi -i testsrc2=duration=1:size=200x200:rate=30 -c:v libx264 -pix_fmt yuv420p10le -preset fast -crf 23 h264_10bits.mp4 +H264_10BITS = TestVideo( + filename="h264_10bits.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=200, height=200, num_color_channels=3), + }, + frames={0: {}}, # Not needed yet +) + @dataclass class TestAudio(TestContainerFile):