@@ -199,12 +199,127 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199
199
return ;
200
200
}
201
201
202
+ std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext (
203
+ const VideoStreamOptions& videoStreamOptions,
204
+ const UniqueAVFrame& avFrame,
205
+ const AVRational& timeBase) {
206
+ // We need FFmpeg filters to handle those conversion cases which are not
207
+ // directly implemented in CUDA or CPU device interface (in case of a
208
+ // fallback).
209
+ enum AVPixelFormat frameFormat =
210
+ static_cast <enum AVPixelFormat>(avFrame->format );
211
+
212
+ // Input frame is on CPU, we will just pass it to CPU device interface, so
213
+ // skipping filters context as CPU device interface will handle everythong for
214
+ // us.
215
+ if (avFrame->format != AV_PIX_FMT_CUDA) {
216
+ return nullptr ;
217
+ }
218
+
219
+ TORCH_CHECK (
220
+ avFrame->hw_frames_ctx != nullptr ,
221
+ " The AVFrame does not have a hw_frames_ctx. "
222
+ " That's unexpected, please report this to the TorchCodec repo." );
223
+
224
+ auto hwFramesCtx =
225
+ reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
226
+ AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
227
+
228
+ // NV12 conversion is implemented directly with NPP, no need for filters.
229
+ if (actualFormat == AV_PIX_FMT_NV12) {
230
+ return nullptr ;
231
+ }
232
+
233
+ auto frameDims =
234
+ getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
235
+ int height = frameDims.height ;
236
+ int width = frameDims.width ;
237
+
238
+ AVPixelFormat outputFormat;
239
+ std::stringstream filters;
240
+
241
+ unsigned version_int = avfilter_version ();
242
+ if (version_int < AV_VERSION_INT (8 , 0 , 103 )) {
243
+ // Color conversion support ('format=' option) was added to scale_cuda from
244
+ // n5.0. With the earlier version of ffmpeg we have no choice but use CPU
245
+ // filters. See:
246
+ // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
247
+ outputFormat = AV_PIX_FMT_RGB24;
248
+
249
+ auto actualFormatName = av_get_pix_fmt_name (actualFormat);
250
+ TORCH_CHECK (
251
+ actualFormatName != nullptr ,
252
+ " The actual format of a frame is unknown to FFmpeg. "
253
+ " That's unexpected, please report this to the TorchCodec repo." );
254
+
255
+ filters << " hwdownload,format=" << actualFormatName;
256
+ filters << " ,scale=" << width << " :" << height;
257
+ filters << " :sws_flags=bilinear" ;
258
+ } else {
259
+ // Actual output color format will be set via filter options
260
+ outputFormat = AV_PIX_FMT_CUDA;
261
+
262
+ filters << " scale_cuda=" << width << " :" << height;
263
+ filters << " :format=nv12:interp_algo=bilinear" ;
264
+ }
265
+
266
+ return std::make_unique<FiltersContext>(
267
+ avFrame->width ,
268
+ avFrame->height ,
269
+ frameFormat,
270
+ avFrame->sample_aspect_ratio ,
271
+ width,
272
+ height,
273
+ outputFormat,
274
+ filters.str (),
275
+ timeBase,
276
+ av_buffer_ref (avFrame->hw_frames_ctx ));
277
+ }
278
+
202
279
void CudaDeviceInterface::convertAVFrameToFrameOutput (
203
280
const VideoStreamOptions& videoStreamOptions,
204
281
[[maybe_unused]] const AVRational& timeBase,
205
- UniqueAVFrame& avFrame ,
282
+ UniqueAVFrame& avInputFrame ,
206
283
FrameOutput& frameOutput,
207
284
std::optional<torch::Tensor> preAllocatedOutputTensor) {
285
+ std::unique_ptr<FiltersContext> newFiltersContext =
286
+ initializeFiltersContext (videoStreamOptions, avInputFrame, timeBase);
287
+ UniqueAVFrame avFilteredFrame;
288
+ if (newFiltersContext) {
289
+ // We need to compare the current filter context with our previous filter
290
+ // context. If they are different, then we need to re-create a filter
291
+ // graph. We create a filter graph late so that we don't have to depend
292
+ // on the unreliable metadata in the header. And we sometimes re-create
293
+ // it because it's possible for frame resolution to change mid-stream.
294
+ // Finally, we want to reuse the filter graph as much as possible for
295
+ // performance reasons.
296
+ if (!filterGraph_ || *filtersContext_ != *newFiltersContext) {
297
+ filterGraph_ =
298
+ std::make_unique<FilterGraph>(*newFiltersContext, videoStreamOptions);
299
+ filtersContext_ = std::move (newFiltersContext);
300
+ }
301
+ avFilteredFrame = filterGraph_->convert (avInputFrame);
302
+
303
+ // If this check fails it means the frame wasn't
304
+ // reshaped to its expected dimensions by filtergraph.
305
+ TORCH_CHECK (
306
+ (avFilteredFrame->width == filtersContext_->outputWidth ) &&
307
+ (avFilteredFrame->height == filtersContext_->outputHeight ),
308
+ " Expected frame from filter graph of " ,
309
+ filtersContext_->outputWidth ,
310
+ " x" ,
311
+ filtersContext_->outputHeight ,
312
+ " , got " ,
313
+ avFilteredFrame->width ,
314
+ " x" ,
315
+ avFilteredFrame->height );
316
+ }
317
+
318
+ UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame;
319
+
320
+ // The filtered frame might be on CPU if CPU fallback has happened on filter
321
+ // graph level. For example, that's how we handle color format conversion
322
+ // on FFmpeg 4.4, where scale_cuda did not have this support implemented yet.
208
323
if (avFrame->format != AV_PIX_FMT_CUDA) {
209
324
// The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
210
325
// the GPU. In this branch, the frame is on the CPU: this is what NVDEC
@@ -232,8 +347,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
232
347
// Above we checked that the AVFrame was on GPU, but that's not enough, we
233
348
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
234
349
// because this is what the NPP color conversion routines expect.
235
- // TODO: we should investigate how to can perform color conversion for
236
- // non-8bit videos. This is supported on CPU.
237
350
TORCH_CHECK (
238
351
avFrame->hw_frames_ctx != nullptr ,
239
352
" The AVFrame does not have a hw_frames_ctx. "
@@ -242,16 +355,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
242
355
auto hwFramesCtx =
243
356
reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
244
357
AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
358
+
245
359
TORCH_CHECK (
246
360
actualFormat == AV_PIX_FMT_NV12,
247
361
" The AVFrame is " ,
248
362
(av_get_pix_fmt_name (actualFormat) ? av_get_pix_fmt_name (actualFormat)
249
363
: " unknown" ),
250
- " , but we expected AV_PIX_FMT_NV12. This typically happens when "
251
- " the video isn't 8bit, which is not supported on CUDA at the moment. "
252
- " Try using the CPU device instead. "
253
- " If the video is 10bit, we are tracking 10bit support in "
254
- " https://github.com/pytorch/torchcodec/issues/776" );
364
+ " , but we expected AV_PIX_FMT_NV12. "
365
+ " That's unexpected, please report this to the TorchCodec repo." );
255
366
256
367
auto frameDims =
257
368
getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
0 commit comments