Skip to content

Commit 8662ee1

Browse files
authored
Refactor context structs to use constructors (#869)
1 parent da3edda commit 8662ee1

File tree

4 files changed

+75
-25
lines changed

4 files changed

+75
-25
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ static bool g_cpu = registerDeviceInterface(
1515

1616
} // namespace
1717

18+
CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
19+
int inputWidth,
20+
int inputHeight,
21+
AVPixelFormat inputFormat,
22+
int outputWidth,
23+
int outputHeight)
24+
: inputWidth(inputWidth),
25+
inputHeight(inputHeight),
26+
inputFormat(inputFormat),
27+
outputWidth(outputWidth),
28+
outputHeight(outputHeight) {}
29+
1830
bool CpuDeviceInterface::SwsFrameContext::operator==(
1931
const CpuDeviceInterface::SwsFrameContext& other) const {
2032
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
@@ -97,13 +109,12 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
97109
// And we sometimes re-create them because it's possible for frame
98110
// resolution to change mid-stream. Finally, we want to reuse the colorspace
99111
// conversion objects as much as possible for performance reasons.
100-
SwsFrameContext swsFrameContext;
101-
102-
swsFrameContext.inputWidth = avFrame->width;
103-
swsFrameContext.inputHeight = avFrame->height;
104-
swsFrameContext.inputFormat = frameFormat;
105-
swsFrameContext.outputWidth = expectedOutputWidth;
106-
swsFrameContext.outputHeight = expectedOutputHeight;
112+
SwsFrameContext swsFrameContext(
113+
avFrame->width,
114+
avFrame->height,
115+
frameFormat,
116+
expectedOutputWidth,
117+
expectedOutputHeight);
107118

108119
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
109120
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
@@ -128,22 +139,20 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
128139
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
129140
// See comment above in swscale branch about the filterGraphContext_
130141
// creation. creation
131-
FiltersContext filtersContext;
132-
133-
filtersContext.inputWidth = avFrame->width;
134-
filtersContext.inputHeight = avFrame->height;
135-
filtersContext.inputFormat = frameFormat;
136-
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
137-
filtersContext.outputWidth = expectedOutputWidth;
138-
filtersContext.outputHeight = expectedOutputHeight;
139-
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
140-
filtersContext.timeBase = timeBase;
141-
142142
std::stringstream filters;
143143
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
144144
filters << ":sws_flags=bilinear";
145145

146-
filtersContext.filtergraphStr = filters.str();
146+
FiltersContext filtersContext(
147+
avFrame->width,
148+
avFrame->height,
149+
frameFormat,
150+
avFrame->sample_aspect_ratio,
151+
expectedOutputWidth,
152+
expectedOutputHeight,
153+
AV_PIX_FMT_RGB24,
154+
filters.str(),
155+
timeBase);
147156

148157
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
149158
filterGraphContext_ =

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,19 @@ class CpuDeviceInterface : public DeviceInterface {
4343
const UniqueAVFrame& avFrame);
4444

4545
struct SwsFrameContext {
46-
int inputWidth;
47-
int inputHeight;
48-
AVPixelFormat inputFormat;
49-
int outputWidth;
50-
int outputHeight;
46+
int inputWidth = 0;
47+
int inputHeight = 0;
48+
AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
49+
int outputWidth = 0;
50+
int outputHeight = 0;
51+
52+
SwsFrameContext() = default;
53+
SwsFrameContext(
54+
int inputWidth,
55+
int inputHeight,
56+
AVPixelFormat inputFormat,
57+
int outputWidth,
58+
int outputHeight);
5159
bool operator==(const SwsFrameContext&) const;
5260
bool operator!=(const SwsFrameContext&) const;
5361
};

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,26 @@ extern "C" {
1313

1414
namespace facebook::torchcodec {
1515

16+
FiltersContext::FiltersContext(
17+
int inputWidth,
18+
int inputHeight,
19+
AVPixelFormat inputFormat,
20+
AVRational inputAspectRatio,
21+
int outputWidth,
22+
int outputHeight,
23+
AVPixelFormat outputFormat,
24+
const std::string& filtergraphStr,
25+
AVRational timeBase)
26+
: inputWidth(inputWidth),
27+
inputHeight(inputHeight),
28+
inputFormat(inputFormat),
29+
inputAspectRatio(inputAspectRatio),
30+
outputWidth(outputWidth),
31+
outputHeight(outputHeight),
32+
outputFormat(outputFormat),
33+
filtergraphStr(filtergraphStr),
34+
timeBase(timeBase) {}
35+
1636
bool operator==(const AVRational& lhs, const AVRational& rhs) {
1737
return lhs.num == rhs.num && lhs.den == rhs.den;
1838
}

src/torchcodec/_core/FilterGraph.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,24 @@ struct FiltersContext {
1919
int outputWidth = 0;
2020
int outputHeight = 0;
2121
AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
22-
2322
std::string filtergraphStr;
2423
AVRational timeBase = {0, 0};
2524
UniqueAVBufferRef hwFramesCtx;
2625

26+
FiltersContext() = default;
27+
FiltersContext(FiltersContext&&) = default;
28+
FiltersContext& operator=(FiltersContext&&) = default;
29+
FiltersContext(
30+
int inputWidth,
31+
int inputHeight,
32+
AVPixelFormat inputFormat,
33+
AVRational inputAspectRatio,
34+
int outputWidth,
35+
int outputHeight,
36+
AVPixelFormat outputFormat,
37+
const std::string& filtergraphStr,
38+
AVRational timeBase);
39+
2740
bool operator==(const FiltersContext&) const;
2841
bool operator!=(const FiltersContext&) const;
2942
};

0 commit comments

Comments
 (0)