Skip to content

Commit fc60ed6

Browse files
authored
Use cuda filters to support 10-bit videos (#899)
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent 30704cf commit fc60ed6

File tree

7 files changed

+167
-44
lines changed

7 files changed

+167
-44
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,24 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
8383
enum AVPixelFormat frameFormat =
8484
static_cast<enum AVPixelFormat>(avFrame->format);
8585

86+
// This is an early-return optimization: if the format is already what we
87+
// need, and the dimensions are also what we need, we don't need to call
88+
// swscale or filtergraph. We can just convert the AVFrame to a tensor.
89+
if (frameFormat == AV_PIX_FMT_RGB24 &&
90+
avFrame->width == expectedOutputWidth &&
91+
avFrame->height == expectedOutputHeight) {
92+
outputTensor = toTensor(avFrame);
93+
if (preAllocatedOutputTensor.has_value()) {
94+
// We have already validated that preAllocatedOutputTensor and
95+
// outputTensor have the same shape.
96+
preAllocatedOutputTensor.value().copy_(outputTensor);
97+
frameOutput.data = preAllocatedOutputTensor.value();
98+
} else {
99+
frameOutput.data = outputTensor;
100+
}
101+
return;
102+
}
103+
86104
// By default, we want to use swscale for color conversion because it is
87105
// faster. However, it has width requirements, so we may need to fall back
88106
// to filtergraph. We also need to respect what was requested from the
@@ -159,7 +177,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
159177
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
160178
prevFiltersContext_ = std::move(filtersContext);
161179
}
162-
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
180+
outputTensor = toTensor(filterGraphContext_->convert(avFrame));
163181

164182
// Similarly to above, if this check fails it means the frame wasn't
165183
// reshaped to its expected dimensions by filtergraph.
@@ -208,23 +226,20 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
208226
return resultHeight;
209227
}
210228

211-
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
212-
const UniqueAVFrame& avFrame) {
213-
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
214-
215-
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
229+
torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) {
230+
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
216231

217-
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
232+
auto frameDims = getHeightAndWidthFromResizedAVFrame(*avFrame.get());
218233
int height = frameDims.height;
219234
int width = frameDims.width;
220235
std::vector<int64_t> shape = {height, width, 3};
221-
std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
222-
AVFrame* filteredAVFramePtr = filteredAVFrame.release();
223-
auto deleter = [filteredAVFramePtr](void*) {
224-
UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
236+
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
237+
AVFrame* avFrameClone = av_frame_clone(avFrame.get());
238+
auto deleter = [avFrameClone](void*) {
239+
UniqueAVFrame avFrameToDelete(avFrameClone);
225240
};
226241
return torch::from_blob(
227-
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
242+
avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
228243
}
229244

230245
void CpuDeviceInterface::createSwsContext(

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class CpuDeviceInterface : public DeviceInterface {
3939
const UniqueAVFrame& avFrame,
4040
torch::Tensor& outputTensor);
4141

42-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
43-
const UniqueAVFrame& avFrame);
42+
torch::Tensor toTensor(const UniqueAVFrame& avFrame);
4443

4544
struct SwsFrameContext {
4645
int inputWidth = 0;

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 119 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,127 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199199
return;
200200
}
201201

202+
std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
203+
const VideoStreamOptions& videoStreamOptions,
204+
const UniqueAVFrame& avFrame,
205+
const AVRational& timeBase) {
206+
// We need FFmpeg filters to handle those conversion cases which are not
207+
// directly implemented in CUDA or CPU device interface (in case of a
208+
// fallback).
209+
enum AVPixelFormat frameFormat =
210+
static_cast<enum AVPixelFormat>(avFrame->format);
211+
212+
// Input frame is on CPU, we will just pass it to CPU device interface, so
213+
skipping filters context as CPU device interface will handle everything for
214+
// us.
215+
if (avFrame->format != AV_PIX_FMT_CUDA) {
216+
return nullptr;
217+
}
218+
219+
TORCH_CHECK(
220+
avFrame->hw_frames_ctx != nullptr,
221+
"The AVFrame does not have a hw_frames_ctx. "
222+
"That's unexpected, please report this to the TorchCodec repo.");
223+
224+
auto hwFramesCtx =
225+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
226+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
227+
228+
// NV12 conversion is implemented directly with NPP, no need for filters.
229+
if (actualFormat == AV_PIX_FMT_NV12) {
230+
return nullptr;
231+
}
232+
233+
auto frameDims =
234+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
235+
int height = frameDims.height;
236+
int width = frameDims.width;
237+
238+
AVPixelFormat outputFormat;
239+
std::stringstream filters;
240+
241+
unsigned version_int = avfilter_version();
242+
if (version_int < AV_VERSION_INT(8, 0, 103)) {
243+
// Color conversion support ('format=' option) was added to scale_cuda from
244+
// n5.0. With the earlier version of ffmpeg we have no choice but use CPU
245+
// filters. See:
246+
// https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
247+
outputFormat = AV_PIX_FMT_RGB24;
248+
249+
auto actualFormatName = av_get_pix_fmt_name(actualFormat);
250+
TORCH_CHECK(
251+
actualFormatName != nullptr,
252+
"The actual format of a frame is unknown to FFmpeg. "
253+
"That's unexpected, please report this to the TorchCodec repo.");
254+
255+
filters << "hwdownload,format=" << actualFormatName;
256+
filters << ",scale=" << width << ":" << height;
257+
filters << ":sws_flags=bilinear";
258+
} else {
259+
// Actual output color format will be set via filter options
260+
outputFormat = AV_PIX_FMT_CUDA;
261+
262+
filters << "scale_cuda=" << width << ":" << height;
263+
filters << ":format=nv12:interp_algo=bilinear";
264+
}
265+
266+
return std::make_unique<FiltersContext>(
267+
avFrame->width,
268+
avFrame->height,
269+
frameFormat,
270+
avFrame->sample_aspect_ratio,
271+
width,
272+
height,
273+
outputFormat,
274+
filters.str(),
275+
timeBase,
276+
av_buffer_ref(avFrame->hw_frames_ctx));
277+
}
278+
202279
void CudaDeviceInterface::convertAVFrameToFrameOutput(
203280
const VideoStreamOptions& videoStreamOptions,
204281
[[maybe_unused]] const AVRational& timeBase,
205-
UniqueAVFrame& avFrame,
282+
UniqueAVFrame& avInputFrame,
206283
FrameOutput& frameOutput,
207284
std::optional<torch::Tensor> preAllocatedOutputTensor) {
285+
std::unique_ptr<FiltersContext> newFiltersContext =
286+
initializeFiltersContext(videoStreamOptions, avInputFrame, timeBase);
287+
UniqueAVFrame avFilteredFrame;
288+
if (newFiltersContext) {
289+
// We need to compare the current filter context with our previous filter
290+
// context. If they are different, then we need to re-create a filter
291+
// graph. We create a filter graph late so that we don't have to depend
292+
// on the unreliable metadata in the header. And we sometimes re-create
293+
// it because it's possible for frame resolution to change mid-stream.
294+
// Finally, we want to reuse the filter graph as much as possible for
295+
// performance reasons.
296+
if (!filterGraph_ || *filtersContext_ != *newFiltersContext) {
297+
filterGraph_ =
298+
std::make_unique<FilterGraph>(*newFiltersContext, videoStreamOptions);
299+
filtersContext_ = std::move(newFiltersContext);
300+
}
301+
avFilteredFrame = filterGraph_->convert(avInputFrame);
302+
303+
// If this check fails it means the frame wasn't
304+
// reshaped to its expected dimensions by filtergraph.
305+
TORCH_CHECK(
306+
(avFilteredFrame->width == filtersContext_->outputWidth) &&
307+
(avFilteredFrame->height == filtersContext_->outputHeight),
308+
"Expected frame from filter graph of ",
309+
filtersContext_->outputWidth,
310+
"x",
311+
filtersContext_->outputHeight,
312+
", got ",
313+
avFilteredFrame->width,
314+
"x",
315+
avFilteredFrame->height);
316+
}
317+
318+
UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame;
319+
320+
// The filtered frame might be on CPU if CPU fallback has happened on filter
321+
// graph level. For example, that's how we handle color format conversion
322+
on FFmpeg 4.4 where scale_cuda did not have this support implemented yet.
208323
if (avFrame->format != AV_PIX_FMT_CUDA) {
209324
// The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
210325
// the GPU. In this branch, the frame is on the CPU: this is what NVDEC
@@ -232,8 +347,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
232347
// Above we checked that the AVFrame was on GPU, but that's not enough, we
233348
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
234349
// because this is what the NPP color conversion routines expect.
235-
// TODO: we should investigate how to can perform color conversion for
236-
// non-8bit videos. This is supported on CPU.
237350
TORCH_CHECK(
238351
avFrame->hw_frames_ctx != nullptr,
239352
"The AVFrame does not have a hw_frames_ctx. "
@@ -242,16 +355,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
242355
auto hwFramesCtx =
243356
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
244357
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
358+
245359
TORCH_CHECK(
246360
actualFormat == AV_PIX_FMT_NV12,
247361
"The AVFrame is ",
248362
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
249363
: "unknown"),
250-
", but we expected AV_PIX_FMT_NV12. This typically happens when "
251-
"the video isn't 8bit, which is not supported on CUDA at the moment. "
252-
"Try using the CPU device instead. "
253-
"If the video is 10bit, we are tracking 10bit support in "
254-
"https://github.com/pytorch/torchcodec/issues/776");
364+
", but we expected AV_PIX_FMT_NV12. "
365+
"That's unexpected, please report this to the TorchCodec repo.");
255366

256367
auto frameDims =
257368
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <npp.h>
1010
#include "src/torchcodec/_core/DeviceInterface.h"
11+
#include "src/torchcodec/_core/FilterGraph.h"
1112

1213
namespace facebook::torchcodec {
1314

@@ -30,8 +31,17 @@ class CudaDeviceInterface : public DeviceInterface {
3031
std::nullopt) override;
3132

3233
private:
34+
std::unique_ptr<FiltersContext> initializeFiltersContext(
35+
const VideoStreamOptions& videoStreamOptions,
36+
const UniqueAVFrame& avFrame,
37+
const AVRational& timeBase);
38+
3339
UniqueAVBufferRef ctx_;
3440
std::unique_ptr<NppStreamContext> nppCtx_;
41+
// Current filter context. Used to know whether a new FilterGraph
42+
// should be created to process the next frame.
43+
std::unique_ptr<FiltersContext> filtersContext_;
44+
std::unique_ptr<FilterGraph> filterGraph_;
3545
};
3646

3747
} // namespace facebook::torchcodec

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ FiltersContext::FiltersContext(
2222
int outputHeight,
2323
AVPixelFormat outputFormat,
2424
const std::string& filtergraphStr,
25-
AVRational timeBase)
25+
AVRational timeBase,
26+
AVBufferRef* hwFramesCtx)
2627
: inputWidth(inputWidth),
2728
inputHeight(inputHeight),
2829
inputFormat(inputFormat),
@@ -31,7 +32,8 @@ FiltersContext::FiltersContext(
3132
outputHeight(outputHeight),
3233
outputFormat(outputFormat),
3334
filtergraphStr(filtergraphStr),
34-
timeBase(timeBase) {}
35+
timeBase(timeBase),
36+
hwFramesCtx(hwFramesCtx) {}
3537

3638
bool operator==(const AVRational& lhs, const AVRational& rhs) {
3739
return lhs.num == rhs.num && lhs.den == rhs.den;

src/torchcodec/_core/FilterGraph.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ struct FiltersContext {
3535
int outputHeight,
3636
AVPixelFormat outputFormat,
3737
const std::string& filtergraphStr,
38-
AVRational timeBase);
38+
AVRational timeBase,
39+
AVBufferRef* hwFramesCtx = nullptr);
3940

4041
bool operator==(const FiltersContext&) const;
4142
bool operator!=(const FiltersContext&) const;

test/test_decoders.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,22 +1228,6 @@ def test_full_and_studio_range_bt709_video(self, asset):
12281228
elif cuda_version_used_for_building_torch() == (12, 8):
12291229
assert psnr(gpu_frame, cpu_frame) > 20
12301230

1231-
@needs_cuda
1232-
def test_10bit_videos_cuda(self):
1233-
# Assert that we raise proper error on different kinds of 10bit videos.
1234-
1235-
# TODO we should investigate how to support 10bit videos on GPU.
1236-
# See https://github.com/pytorch/torchcodec/issues/776
1237-
1238-
asset = H265_10BITS
1239-
1240-
decoder = VideoDecoder(asset.path, device="cuda")
1241-
with pytest.raises(
1242-
RuntimeError,
1243-
match="The AVFrame is p010le, but we expected AV_PIX_FMT_NV12.",
1244-
):
1245-
decoder.get_frame_at(0)
1246-
12471231
@needs_cuda
12481232
def test_10bit_gpu_fallsback_to_cpu(self):
12491233
# Test for 10-bit videos that aren't supported by NVDEC: we decode and
@@ -1275,12 +1259,13 @@ def test_10bit_gpu_fallsback_to_cpu(self):
12751259
frames_cpu = decoder_cpu.get_frames_at(frame_indices).data
12761260
assert_frames_equal(frames_gpu.cpu(), frames_cpu)
12771261

1262+
@pytest.mark.parametrize("device", all_supported_devices())
12781263
@pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
1279-
def test_10bit_videos_cpu(self, asset):
1280-
# This just validates that we can decode 10-bit videos on CPU.
1264+
def test_10bit_videos(self, device, asset):
1265+
# This just validates that we can decode 10-bit videos.
12811266
# TODO validate against the ref that the decoded frames are correct
12821267

1283-
decoder = VideoDecoder(asset.path)
1268+
decoder = VideoDecoder(asset.path, device=device)
12841269
decoder.get_frame_at(10)
12851270

12861271
def setup_frame_mappings(tmp_path, file, stream_index):

0 commit comments

Comments
 (0)