Skip to content

Commit 39461cb

Browse files
committed
Implement initializeFiltersContext for CPU device interface
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent 9823338 commit 39461cb

File tree

2 files changed

+78
-84
lines changed

2 files changed

+78
-84
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 71 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@ static bool g_cpu = registerDeviceInterface(
1313
torch::kCPU,
1414
[](const torch::Device& device) { return new CpuDeviceInterface(device); });
1515

16+
ColorConversionLibrary getColorConversionLibrary(
17+
const VideoStreamOptions& videoStreamOptions,
18+
int width) {
19+
// By default, we want to use swscale for color conversion because it is
20+
// faster. However, it has width requirements, so we may need to fall back
21+
// to filtergraph. We also need to respect what was requested from the
22+
// options; we respect the options unconditionally, so it's possible for
23+
// swscale's width requirements to be violated. We don't expose the ability to
24+
// choose color conversion library publicly; we only use this ability
25+
// internally.
26+
27+
// swscale requires widths to be multiples of 32:
28+
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
29+
// so we fall back to filtergraph if the width is not a multiple of 32.
30+
auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
31+
: ColorConversionLibrary::FILTERGRAPH;
32+
33+
ColorConversionLibrary colorConversionLibrary =
34+
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
35+
36+
TORCH_CHECK(
37+
colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
38+
colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
39+
"Invalid color conversion library: ",
40+
static_cast<int>(colorConversionLibrary));
41+
return colorConversionLibrary;
42+
}
43+
1644
} // namespace
1745

1846
CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
@@ -46,6 +74,38 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4674
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
4775
}
4876

77+
std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
78+
const VideoStreamOptions& videoStreamOptions,
79+
const UniqueAVFrame& avFrame,
80+
const AVRational& timeBase) {
81+
enum AVPixelFormat frameFormat =
82+
static_cast<enum AVPixelFormat>(avFrame->format);
83+
auto frameDims =
84+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
85+
int expectedOutputHeight = frameDims.height;
86+
int expectedOutputWidth = frameDims.width;
87+
88+
if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) ==
89+
ColorConversionLibrary::SWSCALE) {
90+
return nullptr;
91+
}
92+
93+
std::stringstream filters;
94+
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
95+
filters << ":sws_flags=bilinear";
96+
97+
return std::make_unique<FiltersContext>(
98+
avFrame->width,
99+
avFrame->height,
100+
frameFormat,
101+
avFrame->sample_aspect_ratio,
102+
expectedOutputWidth,
103+
expectedOutputHeight,
104+
AV_PIX_FMT_RGB24,
105+
filters.str(),
106+
timeBase);
107+
}
108+
49109
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
50110
// Callers may pass a pre-allocated tensor, where the output.data tensor will
51111
// be stored. This parameter is honored in any case, but it only leads to a
@@ -57,7 +117,7 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
57117
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
58118
void CpuDeviceInterface::convertAVFrameToFrameOutput(
59119
const VideoStreamOptions& videoStreamOptions,
60-
const AVRational& timeBase,
120+
[[maybe_unused]] const AVRational& timeBase,
61121
UniqueAVFrame& avFrame,
62122
FrameOutput& frameOutput,
63123
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -83,23 +143,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
83143
enum AVPixelFormat frameFormat =
84144
static_cast<enum AVPixelFormat>(avFrame->format);
85145

86-
// By default, we want to use swscale for color conversion because it is
87-
// faster. However, it has width requirements, so we may need to fall back
88-
// to filtergraph. We also need to respect what was requested from the
89-
// options; we respect the options unconditionally, so it's possible for
90-
// swscale's width requirements to be violated. We don't expose the ability to
91-
// choose color conversion library publicly; we only use this ability
92-
// internally.
93-
94-
// swscale requires widths to be multiples of 32:
95-
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
96-
// so we fall back to filtergraph if the width is not a multiple of 32.
97-
auto defaultLibrary = (expectedOutputWidth % 32 == 0)
98-
? ColorConversionLibrary::SWSCALE
99-
: ColorConversionLibrary::FILTERGRAPH;
100-
101146
ColorConversionLibrary colorConversionLibrary =
102-
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
147+
getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
103148

104149
if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
105150
// We need to compare the current frame context with our previous frame
@@ -137,42 +182,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
137182

138183
frameOutput.data = outputTensor;
139184
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
140-
// See comment above in swscale branch about the filterGraphContext_
141-
// creation.
142-
std::stringstream filters;
143-
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
144-
filters << ":sws_flags=bilinear";
185+
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
145186

146-
FiltersContext filtersContext(
147-
avFrame->width,
148-
avFrame->height,
149-
frameFormat,
150-
avFrame->sample_aspect_ratio,
151-
expectedOutputWidth,
152-
expectedOutputHeight,
153-
AV_PIX_FMT_RGB24,
154-
filters.str(),
155-
timeBase);
156-
157-
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
158-
filterGraphContext_ =
159-
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
160-
prevFiltersContext_ = std::move(filtersContext);
161-
}
162-
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
163-
164-
// Similarly to above, if this check fails it means the frame wasn't
165-
// reshaped to its expected dimensions by filtergraph.
166-
auto shape = outputTensor.sizes();
167-
TORCH_CHECK(
168-
(shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
169-
(shape[1] == expectedOutputWidth) && (shape[2] == 3),
170-
"Expected output tensor of shape ",
171-
expectedOutputHeight,
172-
"x",
173-
expectedOutputWidth,
174-
"x3, got ",
175-
shape);
187+
std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
188+
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
189+
AVFrame* avFramePtr = avFrame.release();
190+
auto deleter = [avFramePtr](void*) {
191+
UniqueAVFrame avFrameToDelete(avFramePtr);
192+
};
193+
outputTensor = torch::from_blob(
194+
avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
176195

177196
if (preAllocatedOutputTensor.has_value()) {
178197
// We have already validated that preAllocatedOutputTensor and
@@ -182,11 +201,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
182201
} else {
183202
frameOutput.data = outputTensor;
184203
}
185-
} else {
186-
TORCH_CHECK(
187-
false,
188-
"Invalid color conversion library: ",
189-
static_cast<int>(colorConversionLibrary));
190204
}
191205
}
192206

@@ -208,25 +222,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
208222
return resultHeight;
209223
}
210224

211-
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
212-
const UniqueAVFrame& avFrame) {
213-
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
214-
215-
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
216-
217-
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
218-
int height = frameDims.height;
219-
int width = frameDims.width;
220-
std::vector<int64_t> shape = {height, width, 3};
221-
std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
222-
AVFrame* filteredAVFramePtr = filteredAVFrame.release();
223-
auto deleter = [filteredAVFramePtr](void*) {
224-
UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
225-
};
226-
return torch::from_blob(
227-
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
228-
}
229-
230225
void CpuDeviceInterface::createSwsContext(
231226
const SwsFrameContext& swsFrameContext,
232227
const enum AVColorSpace colorspace) {

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ class CpuDeviceInterface : public DeviceInterface {
2626
void initializeContext(
2727
[[maybe_unused]] AVCodecContext* codecContext) override {}
2828

29+
std::unique_ptr<FiltersContext> initializeFiltersContext(
30+
const VideoStreamOptions& videoStreamOptions,
31+
const UniqueAVFrame& avFrame,
32+
const AVRational& timeBase) override;
33+
2934
void convertAVFrameToFrameOutput(
3035
const VideoStreamOptions& videoStreamOptions,
3136
const AVRational& timeBase,
@@ -39,9 +44,6 @@ class CpuDeviceInterface : public DeviceInterface {
3944
const UniqueAVFrame& avFrame,
4045
torch::Tensor& outputTensor);
4146

42-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
43-
const UniqueAVFrame& avFrame);
44-
4547
struct SwsFrameContext {
4648
int inputWidth = 0;
4749
int inputHeight = 0;
@@ -64,15 +66,12 @@ class CpuDeviceInterface : public DeviceInterface {
6466
const SwsFrameContext& swsFrameContext,
6567
const enum AVColorSpace colorspace);
6668

67-
// color-conversion fields. Only one of FilterGraphContext and
68-
// UniqueSwsContext should be non-null.
69-
std::unique_ptr<FilterGraph> filterGraphContext_;
69+
// SWS color conversion context
7070
UniqueSwsContext swsContext_;
7171

72-
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
72+
// Used to know whether a new UniqueSwsContext should
7373
// be created before decoding a new frame.
7474
SwsFrameContext prevSwsFrameContext_;
75-
FiltersContext prevFiltersContext_;
7675
};
7776

7877
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)