Skip to content

Commit 9823338

Browse files
committed
Use cuda filters to support 10-bit videos
For: #776 Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent 8662ee1 commit 9823338

File tree

8 files changed

+129
-22
lines changed

8 files changed

+129
-22
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,68 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199199
return;
200200
}
201201

202+
std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
203+
const VideoStreamOptions& videoStreamOptions,
204+
const UniqueAVFrame& avFrame,
205+
const AVRational& timeBase) {
206+
enum AVPixelFormat frameFormat =
207+
static_cast<enum AVPixelFormat>(avFrame->format);
208+
209+
if (avFrame->format != AV_PIX_FMT_CUDA) {
210+
auto cpuDevice = torch::Device(torch::kCPU);
211+
auto cpuInterface = createDeviceInterface(cpuDevice);
212+
return cpuInterface->initializeFiltersContext(
213+
videoStreamOptions, avFrame, timeBase);
214+
}
215+
216+
auto frameDims =
217+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
218+
int height = frameDims.height;
219+
int width = frameDims.width;
220+
221+
auto hwFramesCtx =
222+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
223+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
224+
225+
if (actualFormat == AV_PIX_FMT_NV12) {
226+
return nullptr;
227+
}
228+
229+
AVPixelFormat outputFormat;
230+
std::stringstream filters;
231+
232+
unsigned version_int = avfilter_version();
233+
if (version_int < AV_VERSION_INT(8, 0, 103)) {
234+
// Color conversion support ('format=' option) was added to scale_cuda from
235+
// n5.0. With the earlier version of ffmpeg we have no choice but use CPU
236+
// filters. See:
237+
// https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
238+
outputFormat = AV_PIX_FMT_RGB24;
239+
240+
filters << "hwdownload,format=" << av_pix_fmt_desc_get(actualFormat)->name;
241+
filters << ",scale=" << width << ":" << height;
242+
filters << ":sws_flags=bilinear";
243+
} else {
244+
// Actual output color format will be set via filter options
245+
outputFormat = AV_PIX_FMT_CUDA;
246+
247+
filters << "scale_cuda=" << width << ":" << height;
248+
filters << ":format=nv12:interp_algo=bilinear";
249+
}
250+
251+
return std::make_unique<FiltersContext>(
252+
avFrame->width,
253+
avFrame->height,
254+
frameFormat,
255+
avFrame->sample_aspect_ratio,
256+
width,
257+
height,
258+
outputFormat,
259+
filters.str(),
260+
timeBase,
261+
av_buffer_ref(avFrame->hw_frames_ctx));
262+
}
263+
202264
void CudaDeviceInterface::convertAVFrameToFrameOutput(
203265
const VideoStreamOptions& videoStreamOptions,
204266
[[maybe_unused]] const AVRational& timeBase,

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ class CudaDeviceInterface : public DeviceInterface {
2121

2222
void initializeContext(AVCodecContext* codecContext) override;
2323

24+
std::unique_ptr<FiltersContext> initializeFiltersContext(
25+
const VideoStreamOptions& videoStreamOptions,
26+
const UniqueAVFrame& avFrame,
27+
const AVRational& timeBase) override;
28+
2429
void convertAVFrameToFrameOutput(
2530
const VideoStreamOptions& videoStreamOptions,
2631
const AVRational& timeBase,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <stdexcept>
1313
#include <string>
1414
#include "FFMPEGCommon.h"
15+
#include "src/torchcodec/_core/FilterGraph.h"
1516
#include "src/torchcodec/_core/Frame.h"
1617
#include "src/torchcodec/_core/StreamOptions.h"
1718

@@ -41,6 +42,18 @@ class DeviceInterface {
4142
// support CUDA and others only support CPU.
4243
virtual void initializeContext(AVCodecContext* codecContext) = 0;
4344

45+
// Returns a FiltersContext if the device interface can't handle conversion of the
46+
// frame on its own within a call to convertAVFrameToFrameOutput().
47+
// FiltersContext contains input and output initialization parameters
48+
// describing required conversion. Output can further be passed to
49+
// convertAVFrameToFrameOutput() to generate output tensor.
50+
virtual std::unique_ptr<FiltersContext> initializeFiltersContext(
51+
[[maybe_unused]] const VideoStreamOptions& videoStreamOptions,
52+
[[maybe_unused]] const UniqueAVFrame& avFrame,
53+
[[maybe_unused]] const AVRational& timeBase) {
54+
return nullptr;
55+
};
56+
4457
virtual void convertAVFrameToFrameOutput(
4558
const VideoStreamOptions& videoStreamOptions,
4659
const AVRational& timeBase,

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ FiltersContext::FiltersContext(
2222
int outputHeight,
2323
AVPixelFormat outputFormat,
2424
const std::string& filtergraphStr,
25-
AVRational timeBase)
25+
AVRational timeBase,
26+
AVBufferRef* hwFramesCtx)
2627
: inputWidth(inputWidth),
2728
inputHeight(inputHeight),
2829
inputFormat(inputFormat),
@@ -31,7 +32,8 @@ FiltersContext::FiltersContext(
3132
outputHeight(outputHeight),
3233
outputFormat(outputFormat),
3334
filtergraphStr(filtergraphStr),
34-
timeBase(timeBase) {}
35+
timeBase(timeBase),
36+
hwFramesCtx(hwFramesCtx) {}
3537

3638
bool operator==(const AVRational& lhs, const AVRational& rhs) {
3739
return lhs.num == rhs.num && lhs.den == rhs.den;

src/torchcodec/_core/FilterGraph.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ struct FiltersContext {
3535
int outputHeight,
3636
AVPixelFormat outputFormat,
3737
const std::string& filtergraphStr,
38-
AVRational timeBase);
38+
AVRational timeBase,
39+
AVBufferRef* hwFramesCtx = nullptr);
3940

4041
bool operator==(const FiltersContext&) const;
4142
bool operator!=(const FiltersContext&) const;

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,41 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
12471247
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
12481248
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
12491249
} else if (deviceInterface_) {
1250+
std::unique_ptr<FiltersContext> newFiltersContext =
1251+
deviceInterface_->initializeFiltersContext(
1252+
streamInfo.videoStreamOptions, avFrame, streamInfo.timeBase);
1253+
// Device interface might return nullptr for the filter context in which
1254+
// case device interface will handle conversion directly in
1255+
// convertAVFrameToFrameOutput().
1256+
if (newFiltersContext) {
1257+
// We need to compare the current filter context with our previous filter
1258+
// context. If they are different, then we need to re-create a filter
1259+
// graph. We create a filter graph late so that we don't have to depend
1260+
// on the unreliable metadata in the header. And we sometimes re-create
1261+
// it because it's possible for frame resolution to change mid-stream.
1262+
// Finally, we want to reuse the filter graph as much as possible for
1263+
// performance reasons.
1264+
if (!filterGraph_ || filtersContext_ != newFiltersContext) {
1265+
filterGraph_ = std::make_unique<FilterGraph>(
1266+
*newFiltersContext, streamInfo.videoStreamOptions);
1267+
filtersContext_ = std::move(newFiltersContext);
1268+
}
1269+
avFrame = filterGraph_->convert(avFrame);
1270+
1271+
// If this check fails it means the frame wasn't
1272+
// reshaped to its expected dimensions by filtergraph.
1273+
TORCH_CHECK(
1274+
(avFrame->width == filtersContext_->outputWidth) &&
1275+
(avFrame->height == filtersContext_->outputHeight),
1276+
"Expected frame from filter graph of ",
1277+
filtersContext_->outputWidth,
1278+
"x",
1279+
filtersContext_->outputHeight,
1280+
", got ",
1281+
avFrame->width,
1282+
"x",
1283+
avFrame->height);
1284+
}
12501285
deviceInterface_->convertAVFrameToFrameOutput(
12511286
streamInfo.videoStreamOptions,
12521287
streamInfo.timeBase,

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ class SingleStreamDecoder {
351351
SeekMode seekMode_;
352352
ContainerMetadata containerMetadata_;
353353
UniqueDecodingAVFormatContext formatContext_;
354+
// Current filter context. Used to know whether a new FilterGraph
355+
// should be created to process the next frame.
356+
std::unique_ptr<FiltersContext> filtersContext_;
357+
std::unique_ptr<FilterGraph> filterGraph_;
354358
std::unique_ptr<DeviceInterface> deviceInterface_;
355359
std::map<int, StreamInfo> streamInfos_;
356360
const int NO_ACTIVE_STREAM = -2;

test/test_decoders.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,22 +1225,6 @@ def test_full_and_studio_range_bt709_video(self, asset):
12251225
elif cuda_version_used_for_building_torch() == (12, 8):
12261226
assert psnr(gpu_frame, cpu_frame) > 20
12271227

1228-
@needs_cuda
1229-
def test_10bit_videos_cuda(self):
1230-
# Assert that we raise proper error on different kinds of 10bit videos.
1231-
1232-
# TODO we should investigate how to support 10bit videos on GPU.
1233-
# See https://github.com/pytorch/torchcodec/issues/776
1234-
1235-
asset = H265_10BITS
1236-
1237-
decoder = VideoDecoder(asset.path, device="cuda")
1238-
with pytest.raises(
1239-
RuntimeError,
1240-
match="The AVFrame is p010le, but we expected AV_PIX_FMT_NV12.",
1241-
):
1242-
decoder.get_frame_at(0)
1243-
12441228
@needs_cuda
12451229
def test_10bit_gpu_fallsback_to_cpu(self):
12461230
# Test for 10-bit videos that aren't supported by NVDEC: we decode and
@@ -1272,12 +1256,13 @@ def test_10bit_gpu_fallsback_to_cpu(self):
12721256
frames_cpu = decoder_cpu.get_frames_at(frame_indices).data
12731257
assert_frames_equal(frames_gpu.cpu(), frames_cpu)
12741258

1259+
@pytest.mark.parametrize("device", all_supported_devices())
12751260
@pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
1276-
def test_10bit_videos_cpu(self, asset):
1277-
# This just validates that we can decode 10-bit videos on CPU.
1261+
def test_10bit_videos(self, device, asset):
1262+
# This just validates that we can decode 10-bit videos.
12781263
# TODO validate against the ref that the decoded frames are correct
12791264

1280-
decoder = VideoDecoder(asset.path)
1265+
decoder = VideoDecoder(asset.path, device=device)
12811266
decoder.get_frame_at(10)
12821267

12831268
def setup_frame_mappings(tmp_path, file, stream_index):

0 commit comments

Comments
 (0)