146 changes: 70 additions & 76 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -13,6 +13,34 @@ static bool g_cpu = registerDeviceInterface(
torch::kCPU,
[](const torch::Device& device) { return new CpuDeviceInterface(device); });

ColorConversionLibrary getColorConversionLibrary(
const VideoStreamOptions& videoStreamOptions,
int width) {
// By default, we want to use swscale for color conversion because it is
// faster. However, swscale has width requirements, so we may need to fall
// back to filtergraph. We honor whatever library the options explicitly
// request, unconditionally, so it is possible for swscale's width
// requirements to be violated. We don't expose the choice of color
// conversion library publicly; we only use this ability internally.

// swscale requires widths to be multiples of 32:
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
// so we fall back to filtergraph if the width is not a multiple of 32.
auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
: ColorConversionLibrary::FILTERGRAPH;

ColorConversionLibrary colorConversionLibrary =
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

TORCH_CHECK(
colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
"Invalid color conversion library: ",
static_cast<int>(colorConversionLibrary));
return colorConversionLibrary;
}

} // namespace
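
To make the selection policy above concrete, here is a standalone sketch that replicates the same logic with stand-in types (the enum and the optional are placeholders for the real `VideoStreamOptions` fields; this is not the project's actual header):

```cpp
// Minimal sketch of getColorConversionLibrary()'s policy, using stand-ins.
#include <cassert>
#include <optional>

enum class ColorConversionLibrary { FILTERGRAPH, SWSCALE };

ColorConversionLibrary pickLibrary(
    std::optional<ColorConversionLibrary> requested,
    int width) {
  // swscale only when the output width is a multiple of 32...
  auto defaultLibrary = (width % 32 == 0)
      ? ColorConversionLibrary::SWSCALE
      : ColorConversionLibrary::FILTERGRAPH;
  // ...unless the options force a library, which is honored unconditionally.
  return requested.value_or(defaultLibrary);
}

int main() {
  assert(pickLibrary(std::nullopt, 1920) == ColorConversionLibrary::SWSCALE);
  assert(pickLibrary(std::nullopt, 854) == ColorConversionLibrary::FILTERGRAPH);
  assert(pickLibrary(ColorConversionLibrary::SWSCALE, 854) ==
         ColorConversionLibrary::SWSCALE);  // may violate the width rule
  return 0;
}
```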

CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
@@ -46,6 +74,38 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
}

std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
const VideoStreamOptions& videoStreamOptions,
const UniqueAVFrame& avFrame,
const AVRational& timeBase) {
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);
auto frameDims =
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
int expectedOutputHeight = frameDims.height;
int expectedOutputWidth = frameDims.width;

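// swscale will do the color conversion directly in
// convertAVFrameToFrameOutput(), so no filtergraph pass is needed.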
if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) ==
ColorConversionLibrary::SWSCALE) {
return nullptr;
}

std::stringstream filters;
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
filters << ":sws_flags=bilinear";

return std::make_unique<FiltersContext>(
avFrame->width,
avFrame->height,
frameFormat,
avFrame->sample_aspect_ratio,
expectedOutputWidth,
expectedOutputHeight,
AV_PIX_FMT_RGB24,
filters.str(),
timeBase);
}

// Note [preAllocatedOutputTensor with swscale and filtergraph]:
// Callers may pass a pre-allocated tensor, where the output.data tensor will
// be stored. This parameter is honored in any case, but it only leads to a
@@ -57,7 +117,6 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
void CpuDeviceInterface::convertAVFrameToFrameOutput(
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase,
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -83,23 +142,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

// By default, we want to use swscale for color conversion because it is
// faster. However, it has width requirements, so we may need to fall back
// to filtergraph. We also need to respect what was requested from the
// options; we respect the options unconditionally, so it's possible for
// swscale's width requirements to be violated. We don't expose the ability to
// choose color conversion library publicly; we only use this ability
// internally.

// swscale requires widths to be multiples of 32:
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
// so we fall back to filtergraph if the width is not a multiple of 32.
auto defaultLibrary = (expectedOutputWidth % 32 == 0)
? ColorConversionLibrary::SWSCALE
: ColorConversionLibrary::FILTERGRAPH;

ColorConversionLibrary colorConversionLibrary =
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);

if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
// We need to compare the current frame context with our previous frame
@@ -137,42 +181,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(

frameOutput.data = outputTensor;
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
// See comment above in swscale branch about the filterGraphContext_
// creation.
std::stringstream filters;
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
filters << ":sws_flags=bilinear";
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);

FiltersContext filtersContext(
avFrame->width,
avFrame->height,
frameFormat,
avFrame->sample_aspect_ratio,
expectedOutputWidth,
expectedOutputHeight,
AV_PIX_FMT_RGB24,
filters.str(),
timeBase);

if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
filterGraphContext_ =
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
prevFiltersContext_ = std::move(filtersContext);
}
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);

// Similarly to above, if this check fails it means the frame wasn't
// reshaped to its expected dimensions by filtergraph.
auto shape = outputTensor.sizes();
TORCH_CHECK(
(shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
(shape[1] == expectedOutputWidth) && (shape[2] == 3),
"Expected output tensor of shape ",
expectedOutputHeight,
"x",
expectedOutputWidth,
"x3, got ",
shape);
std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
AVFrame* avFramePtr = avFrame.release();
auto deleter = [avFramePtr](void*) {
UniqueAVFrame avFrameToDelete(avFramePtr);
};
outputTensor = torch::from_blob(
avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});

if (preAllocatedOutputTensor.has_value()) {
// We have already validated that preAllocatedOutputTensor and
@@ -182,11 +200,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
} else {
frameOutput.data = outputTensor;
}
} else {
TORCH_CHECK(
false,
"Invalid color conversion library: ",
static_cast<int>(colorConversionLibrary));
}
}
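
The filtergraph branch above now receives frames that were already converted to RGB24 by the filtergraph pass driven through initializeFiltersContext(), so it only wraps the pixels zero-copy, handing ownership of the AVFrame to the tensor's deleter — the same pattern the deleted convertAVFrameToTensorUsingFilterGraph used. A self-contained sketch of that hand-off, with a local `UniqueAVFrame` alias standing in for the one in FFMPEGCommon.h:

```cpp
#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
}
#include <memory>
#include <vector>

// Stand-in for the project's UniqueAVFrame (FFMPEGCommon.h).
struct AVFrameDeleter {
  void operator()(AVFrame* f) const {
    av_frame_free(&f);
  }
};
using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;

// Wrap a packed-RGB24 AVFrame's pixels in an HWC uint8 tensor without
// copying. The tensor's deleter takes over the frame and frees it when the
// tensor's storage is released.
torch::Tensor tensorFromRGB24Frame(UniqueAVFrame frame) {
  std::vector<int64_t> shape = {frame->height, frame->width, 3};
  // linesize[0] is the row stride in bytes; for uint8 data that is also the
  // element stride, and rows may be padded past width * 3.
  std::vector<int64_t> strides = {frame->linesize[0], 3, 1};
  AVFrame* rawFrame = frame.release();
  auto deleter = [rawFrame](void*) {
    UniqueAVFrame frameToDelete(rawFrame);  // av_frame_free on destruction
  };
  return torch::from_blob(
      rawFrame->data[0], shape, strides, deleter, {torch::kUInt8});
}
```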

@@ -208,25 +221,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
return resultHeight;
}

torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame) {
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);

TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
int height = frameDims.height;
int width = frameDims.width;
std::vector<int64_t> shape = {height, width, 3};
std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
AVFrame* filteredAVFramePtr = filteredAVFrame.release();
auto deleter = [filteredAVFramePtr](void*) {
UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
};
return torch::from_blob(
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}

void CpuDeviceInterface::createSwsContext(
const SwsFrameContext& swsFrameContext,
const enum AVColorSpace colorspace) {
16 changes: 7 additions & 9 deletions src/torchcodec/_core/CpuDeviceInterface.h
@@ -26,9 +26,13 @@ class CpuDeviceInterface : public DeviceInterface {
void initializeContext(
[[maybe_unused]] AVCodecContext* codecContext) override {}

std::unique_ptr<FiltersContext> initializeFiltersContext(
const VideoStreamOptions& videoStreamOptions,
const UniqueAVFrame& avFrame,
const AVRational& timeBase) override;

void convertAVFrameToFrameOutput(
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase,
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor =
@@ -39,9 +43,6 @@ class CpuDeviceInterface : public DeviceInterface {
const UniqueAVFrame& avFrame,
torch::Tensor& outputTensor);

torch::Tensor convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame);

struct SwsFrameContext {
int inputWidth = 0;
int inputHeight = 0;
@@ -64,15 +65,12 @@
const SwsFrameContext& swsFrameContext,
const enum AVColorSpace colorspace);

// color-conversion fields. Only one of FilterGraphContext and
// UniqueSwsContext should be non-null.
std::unique_ptr<FilterGraph> filterGraphContext_;
// SWS color conversion context
UniqueSwsContext swsContext_;

// Used to know whether a new FilterGraphContext or UniqueSwsContext should
// Used to know whether a new UniqueSwsContext should
// be created before decoding a new frame.
SwsFrameContext prevSwsFrameContext_;
FiltersContext prevFiltersContext_;
};

} // namespace facebook::torchcodec
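
A design note on `prevSwsFrameContext_` above: it implements a recreate-on-change cache, where the expensive sws context is rebuilt only when the incoming frame's properties differ from the previous frame's. A generic standalone sketch of the pattern (ad-hoc `Context` and `Converter` stand-ins, not the real types):

```cpp
#include <memory>

// Ad-hoc stand-ins: Context plays the role of SwsFrameContext, Converter
// the role of the expensive-to-create sws context.
struct Context {
  int width = 0;
  int height = 0;
  bool operator==(const Context& other) const {
    return width == other.width && height == other.height;
  }
  bool operator!=(const Context& other) const {
    return !(*this == other);
  }
};

struct Converter {
  explicit Converter(const Context& /*ctx*/) {
    // Expensive setup happens here (sws_getContext in the real code).
  }
};

class CachedConverter {
 public:
  // Recreate the converter only when the context actually changed.
  Converter& getFor(const Context& current) {
    if (!converter_ || prev_ != current) {
      converter_ = std::make_unique<Converter>(current);
      prev_ = current;
    }
    return *converter_;
  }

 private:
  std::unique_ptr<Converter> converter_;
  Context prev_;
};
```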
69 changes: 63 additions & 6 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -199,9 +199,70 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
return;
}

std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
const VideoStreamOptions& videoStreamOptions,
const UniqueAVFrame& avFrame,
const AVRational& timeBase) {
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

if (avFrame->format != AV_PIX_FMT_CUDA) {
auto cpuDevice = torch::Device(torch::kCPU);
auto cpuInterface = createDeviceInterface(cpuDevice);
return cpuInterface->initializeFiltersContext(
videoStreamOptions, avFrame, timeBase);
}

auto frameDims =
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
int height = frameDims.height;
int width = frameDims.width;

auto hwFramesCtx =
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
AVPixelFormat actualFormat = hwFramesCtx->sw_format;

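// The CUDA interface can convert NV12 on its own inside
// convertAVFrameToFrameOutput(), so no filtergraph pass is needed.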
if (actualFormat == AV_PIX_FMT_NV12) {
return nullptr;
}

AVPixelFormat outputFormat;
std::stringstream filters;

unsigned version_int = avfilter_version();
if (version_int < AV_VERSION_INT(8, 0, 103)) {
// Color conversion support (the 'format=' option) was added to scale_cuda
// in FFmpeg n5.0. With earlier versions of FFmpeg we have no choice but to
// use CPU filters. See:
// https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
outputFormat = AV_PIX_FMT_RGB24;

filters << "hwdownload,format=" << av_pix_fmt_desc_get(actualFormat)->name;
filters << ",scale=" << width << ":" << height;
filters << ":sws_flags=bilinear";
} else {
// Actual output color format will be set via filter options
outputFormat = AV_PIX_FMT_CUDA;

filters << "scale_cuda=" << width << ":" << height;
filters << ":format=nv12:interp_algo=bilinear";
}

return std::make_unique<FiltersContext>(
avFrame->width,
avFrame->height,
frameFormat,
avFrame->sample_aspect_ratio,
width,
height,
outputFormat,
filters.str(),
timeBase,
av_buffer_ref(avFrame->hw_frames_ctx));
}

void CudaDeviceInterface::convertAVFrameToFrameOutput(
const VideoStreamOptions& videoStreamOptions,
[[maybe_unused]] const AVRational& timeBase,
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -219,11 +280,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(

FrameOutput cpuFrameOutput;
cpuInterface->convertAVFrameToFrameOutput(
videoStreamOptions,
timeBase,
avFrame,
cpuFrameOutput,
preAllocatedOutputTensor);
videoStreamOptions, avFrame, cpuFrameOutput, preAllocatedOutputTensor);

frameOutput.data = cpuFrameOutput.data.to(device_);
return;
6 changes: 5 additions & 1 deletion src/torchcodec/_core/CudaDeviceInterface.h
@@ -21,9 +21,13 @@ class CudaDeviceInterface : public DeviceInterface {

void initializeContext(AVCodecContext* codecContext) override;

std::unique_ptr<FiltersContext> initializeFiltersContext(
const VideoStreamOptions& videoStreamOptions,
const UniqueAVFrame& avFrame,
const AVRational& timeBase) override;

void convertAVFrameToFrameOutput(
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase,
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor =
14 changes: 13 additions & 1 deletion src/torchcodec/_core/DeviceInterface.h
@@ -12,6 +12,7 @@
#include <stdexcept>
#include <string>
#include "FFMPEGCommon.h"
#include "src/torchcodec/_core/FilterGraph.h"
#include "src/torchcodec/_core/Frame.h"
#include "src/torchcodec/_core/StreamOptions.h"

@@ -33,9 +34,20 @@ class DeviceInterface {
// support CUDA and others only support CPU.
virtual void initializeContext(AVCodecContext* codecContext) = 0;

// Returns a FiltersContext if the device interface can't handle conversion
// of the frame on its own within a call to convertAVFrameToFrameOutput().
// The FiltersContext describes the input and output parameters of the
// required conversion. The filtered frame can then be passed to
// convertAVFrameToFrameOutput() to generate the output tensor.
virtual std::unique_ptr<FiltersContext> initializeFiltersContext(
[[maybe_unused]] const VideoStreamOptions& videoStreamOptions,
[[maybe_unused]] const UniqueAVFrame& avFrame,
[[maybe_unused]] const AVRational& timeBase) {
return nullptr;
};

virtual void convertAVFrameToFrameOutput(
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase,
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
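
Taken together, the two virtuals split conversion into an optional filtergraph pass plus the device's own conversion. A hedged sketch of the caller-side sequence — illustrative only, since the real driver lives in the decoder outside this diff — using FilterGraph's constructor and convert() as they appear elsewhere in this PR:

```cpp
// Sketch only: assumes the declarations in this header plus FilterGraph.h.
#include "src/torchcodec/_core/DeviceInterface.h"

// Run the filtergraph pass the device asked for (if any), then let the
// device produce the output tensor. Caching the FilterGraph across frames
// (keyed on FiltersContext equality, as CpuDeviceInterface previously did)
// is elided for brevity.
void convertFrame(
    DeviceInterface& device,
    const VideoStreamOptions& videoStreamOptions,
    const AVRational& timeBase,
    UniqueAVFrame& avFrame,
    FrameOutput& frameOutput) {
  std::unique_ptr<FiltersContext> filtersContext =
      device.initializeFiltersContext(videoStreamOptions, avFrame, timeBase);
  if (filtersContext) {
    // The device can't convert this frame on its own: run the described
    // filtergraph first, replacing the frame with the filtered one.
    FilterGraph filterGraph(*filtersContext, videoStreamOptions);
    avFrame = filterGraph.convert(avFrame);
  }
  device.convertAVFrameToFrameOutput(videoStreamOptions, avFrame, frameOutput);
}
```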
6 changes: 4 additions & 2 deletions src/torchcodec/_core/FilterGraph.cpp
@@ -22,7 +22,8 @@ FiltersContext::FiltersContext(
int outputHeight,
AVPixelFormat outputFormat,
const std::string& filtergraphStr,
AVRational timeBase)
AVRational timeBase,
AVBufferRef* hwFramesCtx)
: inputWidth(inputWidth),
inputHeight(inputHeight),
inputFormat(inputFormat),
@@ -31,7 +32,8 @@
outputHeight(outputHeight),
outputFormat(outputFormat),
filtergraphStr(filtergraphStr),
timeBase(timeBase) {}
timeBase(timeBase),
hwFramesCtx(hwFramesCtx) {}

bool operator==(const AVRational& lhs, const AVRational& rhs) {
return lhs.num == rhs.num && lhs.den == rhs.den;
3 changes: 2 additions & 1 deletion src/torchcodec/_core/FilterGraph.h
@@ -35,7 +35,8 @@ struct FiltersContext {
int outputHeight,
AVPixelFormat outputFormat,
const std::string& filtergraphStr,
AVRational timeBase);
AVRational timeBase,
AVBufferRef* hwFramesCtx = nullptr);

bool operator==(const FiltersContext&) const;
bool operator!=(const FiltersContext&) const;