@@ -13,6 +13,34 @@ static bool g_cpu = registerDeviceInterface(
     torch::kCPU,
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
+ColorConversionLibrary getColorConversionLibrary(
+    const VideoStreamOptions& videoStreamOptions,
+    int width) {
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability to
+  // choose color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
+                                          : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  TORCH_CHECK(
+      colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
+          colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
+      "Invalid color conversion library: ",
+      static_cast<int>(colorConversionLibrary));
+  return colorConversionLibrary;
+}
+
 } // namespace
 
 CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
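
The new `getColorConversionLibrary` helper centralizes the swscale-vs-filtergraph decision that was previously inlined in `convertAVFrameToFrameOutput`. Below is a minimal standalone sketch of the same selection rule, with the `VideoStreamOptions` field stood in by a plain `std::optional` (a hypothetical simplification, not the real type):

```cpp
#include <iostream>
#include <optional>

enum class ColorConversionLibrary { SWSCALE, FILTERGRAPH };

// Same rule as getColorConversionLibrary above: prefer swscale, fall back
// to filtergraph when the output width is not a multiple of 32, and let an
// explicit request override the default unconditionally.
ColorConversionLibrary choose(
    std::optional<ColorConversionLibrary> requested, int width) {
  auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
                                          : ColorConversionLibrary::FILTERGRAPH;
  return requested.value_or(defaultLibrary);
}

int main() {
  // 1920 % 32 == 0 -> swscale; 1918 % 32 != 0 -> filtergraph.
  std::cout << (choose(std::nullopt, 1920) == ColorConversionLibrary::SWSCALE)
            << "\n"; // 1
  std::cout << (choose(std::nullopt, 1918) ==
                ColorConversionLibrary::FILTERGRAPH)
            << "\n"; // 1
  // An explicit request wins even if the width violates swscale's requirement.
  std::cout << (choose(ColorConversionLibrary::SWSCALE, 1918) ==
                ColorConversionLibrary::SWSCALE)
            << "\n"; // 1
  return 0;
}
```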
@@ -46,6 +74,38 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) ==
+      ColorConversionLibrary::SWSCALE) {
+    return nullptr;
+  }
+
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  return std::make_unique<FiltersContext>(
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      avFrame->sample_aspect_ratio,
+      expectedOutputWidth,
+      expectedOutputHeight,
+      AV_PIX_FMT_RGB24,
+      filters.str(),
+      timeBase);
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
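
Returning `nullptr` here signals that the caller should take the swscale path; otherwise the context carries an FFmpeg filtergraph description string. As a concrete illustration of what the string construction produces, a sketch for a 1280x720 target:

```cpp
#include <iostream>
#include <sstream>

int main() {
  // Same construction as in initializeFiltersContext above: the result is
  // a plain FFmpeg filtergraph description.
  int expectedOutputWidth = 1280, expectedOutputHeight = 720;
  std::stringstream filters;
  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
  filters << ":sws_flags=bilinear";
  std::cout << filters.str() << "\n"; // scale=1280:720:sws_flags=bilinear
  return 0;
}
```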
@@ -57,7 +117,7 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void CpuDeviceInterface::convertAVFrameToFrameOutput(
     const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase,
+    [[maybe_unused]] const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -83,23 +143,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
   ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+      getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     // We need to compare the current frame context with our previous frame
@@ -137,42 +182,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    // See comment above in swscale branch about the filterGraphContext_
-    // creation.
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
+    TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-    FiltersContext filtersContext(
-        avFrame->width,
-        avFrame->height,
-        frameFormat,
-        avFrame->sample_aspect_ratio,
-        expectedOutputWidth,
-        expectedOutputHeight,
-        AV_PIX_FMT_RGB24,
-        filters.str(),
-        timeBase);
-
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
-
-    // Similarly to above, if this check fails it means the frame wasn't
-    // reshaped to its expected dimensions by filtergraph.
-    auto shape = outputTensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-        "Expected output tensor of shape ",
-        expectedOutputHeight,
-        "x",
-        expectedOutputWidth,
-        "x3, got ",
-        shape);
+    std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
+    std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+    AVFrame* avFramePtr = avFrame.release();
+    auto deleter = [avFramePtr](void*) {
+      UniqueAVFrame avFrameToDelete(avFramePtr);
+    };
+    outputTensor = torch::from_blob(
+        avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
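
The filtergraph branch now adopts the frame's pixel buffer directly instead of copying it: `avFrame.release()` hands ownership to the lambda deleter, so the `AVFrame` is freed only when the tensor's storage is released, and the `{linesize[0], 3, 1}` strides account for FFmpeg's row padding. A minimal sketch of the same zero-copy ownership-transfer pattern, with a `malloc`'d buffer standing in for the AVFrame (a hypothetical stand-in, not FFmpeg's type):

```cpp
#include <torch/torch.h>

#include <cstdint>
#include <cstdlib>
#include <iostream>

int main() {
  const int64_t height = 4, width = 6;
  // Row stride in bytes; may exceed width * 3 due to alignment padding,
  // which is what AVFrame::linesize[0] models for RGB24 data.
  const int64_t linesize = 32;
  auto* data = static_cast<uint8_t*>(std::malloc(linesize * height));

  // HWC uint8 tensor over the existing buffer; no copy is made. The deleter
  // runs when the tensor's storage is destroyed, reclaiming the buffer.
  torch::Tensor t = torch::from_blob(
      data,
      {height, width, 3},
      {linesize, 3, 1},
      [](void* p) { std::free(p); },
      torch::kUInt8);

  std::cout << t.sizes() << "\n"; // [4, 6, 3]
  return 0;
}
```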
@@ -182,11 +201,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     } else {
       frameOutput.data = outputTensor;
     }
-  } else {
-    TORCH_CHECK(
-        false,
-        "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
   }
 }
 
@@ -208,25 +222,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {