Skip to content

Commit 6377dfc

Browse files
authored
BETA CUDA interface: separate it from the default interface (#931)
1 parent ce5667d commit 6377dfc

File tree

7 files changed

+406
-340
lines changed

7 files changed

+406
-340
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,12 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
202202
TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!");
203203
TORCH_CHECK(
204204
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
205+
206+
// Initialize CUDA context with a dummy tensor
207+
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
208+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
209+
210+
nppCtx_ = getNppStreamContext(device_);
205211
}
206212

207213
BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
@@ -222,21 +228,13 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
222228
cuvidDestroyVideoParser(videoParser_);
223229
videoParser_ = nullptr;
224230
}
231+
232+
returnNppStreamContextToCache(device_, std::move(nppCtx_));
225233
}
226234

227235
void BetaCudaDeviceInterface::initialize(
228236
const AVStream* avStream,
229237
const UniqueDecodingAVFormatContext& avFormatCtx) {
230-
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
231-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
232-
233-
auto cudaDevice = torch::Device(torch::kCUDA);
234-
defaultCudaInterface_ =
235-
std::unique_ptr<DeviceInterface>(createDeviceInterface(cudaDevice));
236-
AVCodecContext dummyCodecContext = {};
237-
defaultCudaInterface_->initialize(avStream, avFormatCtx);
238-
defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext);
239-
240238
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
241239
timeBase_ = avStream->time_base;
242240
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
@@ -623,15 +621,19 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
623621
UniqueAVFrame& avFrame,
624622
FrameOutput& frameOutput,
625623
std::optional<torch::Tensor> preAllocatedOutputTensor) {
624+
// TODONVDEC P2: we may need to handle 10bit videos the same way the default
625+
// interface does it with maybeConvertAVFrameToNV12OrRGB24().
626626
TORCH_CHECK(
627627
avFrame->format == AV_PIX_FMT_CUDA,
628628
"Expected CUDA format frame from BETA CUDA interface");
629629

630-
// TODONVDEC P1: we use the 'default' cuda device interface for color
631-
// conversion. That's a temporary hack to make things work. we should abstract
632-
// the color conversion stuff separately.
633-
defaultCudaInterface_->convertAVFrameToFrameOutput(
634-
avFrame, frameOutput, preAllocatedOutputTensor);
630+
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);
631+
632+
at::cuda::CUDAStream nvdecStream =
633+
at::cuda::getCurrentCUDAStream(device_.index());
634+
635+
frameOutput.data = convertNV12FrameToRGB(
636+
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
635637
}
636638

637639
} // namespace facebook::torchcodec

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#pragma once
1717

18+
#include "src/torchcodec/_core/CUDACommon.h"
1819
#include "src/torchcodec/_core/Cache.h"
1920
#include "src/torchcodec/_core/DeviceInterface.h"
2021
#include "src/torchcodec/_core/FFMPEGCommon.h"
@@ -94,10 +95,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
9495

9596
UniqueAVBSFContext bitstreamFilter_;
9697

97-
// Default CUDA interface for color conversion.
98-
// TODONVDEC P2: we shouldn't need to keep a separate instance of the default.
99-
// See other TODO there about how interfaces should be completely independent.
100-
std::unique_ptr<DeviceInterface> defaultCudaInterface_;
98+
// NPP context for color conversion
99+
UniqueNppContext nppCtx_;
101100
};
102101

103102
} // namespace facebook::torchcodec

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function(make_torchcodec_libraries
9999
)
100100

101101
if(ENABLE_CUDA)
102-
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp)
102+
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp)
103103
endif()
104104

105105
set(core_library_dependencies
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/_core/CUDACommon.h"
8+
9+
namespace facebook::torchcodec {
10+
11+
namespace {
12+
13+
// PyTorch can only handle up to 128 GPUs.
14+
// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
15+
const int MAX_CUDA_GPUS = 128;
16+
// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
17+
// Set to a positive number to have a cache of that size.
18+
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
19+
20+
PerGpuCache<NppStreamContext> g_cached_npp_ctxs(
21+
MAX_CUDA_GPUS,
22+
MAX_CONTEXTS_PER_GPU_IN_CACHE);
23+
24+
} // namespace
25+
26+
/* clang-format off */
27+
// Note: [YUV -> RGB Color Conversion, color space and color range]
28+
//
29+
// The frames we get from the decoder (FFmpeg decoder, or NVCUVID) are in YUV
30+
// format. We need to convert them to RGB. This note attempts to describe this
31+
// process. There may be some inaccuracies and approximations that experts will
32+
// notice, but our goal is only to provide a good enough understanding of the
33+
// process for torchcodec developers to implement and maintain it.
34+
// On CPU, filtergraph and swscale handle everything for us. With CUDA, we have
35+
// to do a lot of the heavy lifting ourselves.
36+
//
37+
// Color space and color range
38+
// ---------------------------
39+
// Two main characteristics of a frame will affect the conversion process:
40+
// 1. Color space: This basically defines what YUV values correspond to which
41+
// physical wavelength. No need to go into details here, the point is that
42+
// videos can come in different color spaces, the most common ones being
43+
// BT.601 and BT.709, but there are others.
44+
// In FFmpeg this is represented with AVColorSpace:
45+
// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#aff71a069509a1ad3ff54d53a1c894c85
46+
// 2. Color range: This defines the range of YUV values. There is:
47+
// - full range, also called PC range: AVCOL_RANGE_JPEG
48+
// - and the "limited" range, also called studio or TV range: AVCOL_RANGE_MPEG
49+
// https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#a3da0bf691418bc22c4bcbe6583ad589a
50+
//
51+
// Color space and color range are independent concepts, so we can have a BT.709
52+
// with full range, and another one with limited range. Same for BT.601.
53+
//
54+
// In the first version of this note we'll focus on the full color range. It
55+
// will later be updated to account for the limited range.
56+
//
57+
// Color conversion matrix
58+
// -----------------------
59+
// YUV -> RGB conversion is defined as the reverse process of the RGB -> YUV,
60+
// so this is where we'll start.
61+
// At the core of a RGB -> YUV conversion are the "luma coefficients", which are
62+
// specific to a given color space and defined by the color space standard. In
63+
// FFmpeg they can be found here:
64+
// https://github.com/FFmpeg/FFmpeg/blob/7d606ef0ccf2946a4a21ab1ec23486cadc21864b/libavutil/csp.c#L46-L56
65+
//
66+
// For example, the BT.709 coefficients are: kr=0.2126, kg=0.7152, kb=0.0722
67+
// Coefficients must sum to 1.
68+
//
69+
// Conventionally Y is in [0, 1] range, and U and V are in [-0.5, 0.5] range
70+
// (that's mathematically, in practice they are represented in integer range).
71+
// The conversion is defined as:
72+
// https://en.wikipedia.org/wiki/YCbCr#R'G'B'_to_Y%E2%80%B2PbPr
73+
// Y = kr*R + kg*G + kb*B
74+
// U = (B - Y) * 0.5 / (1 - kb) = (B - Y) / u_scale where u_scale = 2 * (1 - kb)
75+
// V = (R - Y) * 0.5 / (1 - kr) = (R - Y) / v_scale where v_scale = 2 * (1 - kr)
76+
//
77+
// Putting all this into matrix form, we get:
78+
// [Y] = [kr kg kb ] [R]
79+
// [U] [-kr/u_scale -kg/u_scale (1-kb)/u_scale] [G]
80+
// [V] [(1-kr)/v_scale -kg/v_scale -kb/v_scale ] [B]
81+
//
82+
//
83+
// Now, to convert YUV to RGB, we just need to invert this matrix:
84+
// ```py
85+
// import torch
86+
// kr, kg, kb = 0.2126, 0.7152, 0.0722 # BT.709 luma coefficients
87+
// u_scale = 2 * (1 - kb)
88+
// v_scale = 2 * (1 - kr)
89+
//
90+
// rgb_to_yuv = torch.tensor([
91+
// [kr, kg, kb],
92+
// [-kr/u_scale, -kg/u_scale, (1-kb)/u_scale],
93+
// [(1-kr)/v_scale, -kg/v_scale, -kb/v_scale]
94+
// ])
95+
//
96+
// yuv_to_rgb_full = torch.linalg.inv(rgb_to_yuv)
97+
// print("YUV->RGB matrix (Full Range):")
98+
// print(yuv_to_rgb_full)
99+
// ```
100+
// And we get:
101+
// tensor([[ 1.0000e+00, -3.3142e-09, 1.5748e+00],
102+
// [ 1.0000e+00, -1.8732e-01, -4.6812e-01],
103+
// [ 1.0000e+00, 1.8556e+00, 4.6231e-09]])
104+
//
105+
// Which matches https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
106+
//
107+
// Color conversion in NPP
108+
// -----------------------
109+
// https://docs.nvidia.com/cuda/npp/image_color_conversion.html.
110+
//
111+
// NPP provides different ways to convert YUV to RGB:
112+
// - pre-defined color conversion functions like
113+
// nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx and nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx
114+
// which are for BT.709 limited and full range, respectively.
115+
// - generic color conversion functions that accept a custom color conversion
116+
// matrix, called ColorTwist, like nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx
117+
//
118+
// We use the pre-defined functions or the color twist functions depending on
119+
// which one we find to be closer to the CPU results.
120+
//
121+
// The color twist functionality is *partially* described in a section named
122+
// "YUVToRGBColorTwist". Importantly:
123+
//
124+
// - The `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` function takes the YUV data
125+
// and the color-conversion matrix as input. The function itself and the
126+
// matrix assume different ranges for YUV values:
127+
// - The **matrix coefficient** must assume that Y is in [0, 1] and U,V are in
128+
// [-0.5, 0.5]. That's how we defined our matrix above.
129+
// - The function `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` however expects all
130+
// of the input Y, U, V to be in [0, 255]. That's how the data comes out of
131+
// the decoder.
132+
// - But *internally*, `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` needs U and V to
133+
// be centered around 0, i.e. in [-128, 127]. So we need to apply a -128
134+
// offset to U and V. Y doesn't need to be offset. The offset can be applied
135+
// by adding a 4th column to the matrix.
136+
//
137+
//
138+
// So our conversion matrix becomes the following, with new offset column:
139+
// tensor([[ 1.0000e+00, -3.3142e-09, 1.5748e+00, 0]
140+
// [ 1.0000e+00, -1.8732e-01, -4.6812e-01, -128]
141+
// [ 1.0000e+00, 1.8556e+00, 4.6231e-09 , -128]])
142+
//
143+
// And that's what we need to pass for BT.709, full range.
144+
/* clang-format on */
145+
146+
// BT.709 full range color conversion matrix for YUV to RGB conversion.
147+
// See Note [YUV -> RGB Color Conversion, color space and color range]
148+
const Npp32f bt709FullRangeColorTwist[3][4] = {
149+
{1.0f, 0.0f, 1.5748f, 0.0f},
150+
{1.0f, -0.187324273f, -0.468124273f, -128.0f},
151+
{1.0f, 1.8556f, 0.0f, -128.0f}};
152+
153+
torch::Tensor convertNV12FrameToRGB(
154+
UniqueAVFrame& avFrame,
155+
const torch::Device& device,
156+
const UniqueNppContext& nppCtx,
157+
at::cuda::CUDAStream nvdecStream,
158+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
159+
auto frameDims = FrameDims(avFrame->height, avFrame->width);
160+
torch::Tensor dst;
161+
if (preAllocatedOutputTensor.has_value()) {
162+
dst = preAllocatedOutputTensor.value();
163+
} else {
164+
dst = allocateEmptyHWCTensor(frameDims, device);
165+
}
166+
167+
// We need to make sure NVDEC has finished decoding a frame before
168+
// color-converting it with NPP.
169+
// So we make the NPP stream wait for NVDEC to finish.
170+
at::cuda::CUDAStream nppStream =
171+
at::cuda::getCurrentCUDAStream(device.index());
172+
at::cuda::CUDAEvent nvdecDoneEvent;
173+
nvdecDoneEvent.record(nvdecStream);
174+
nvdecDoneEvent.block(nppStream);
175+
176+
nppCtx->hStream = nppStream.stream();
177+
cudaError_t err = cudaStreamGetFlags(nppCtx->hStream, &nppCtx->nStreamFlags);
178+
TORCH_CHECK(
179+
err == cudaSuccess,
180+
"cudaStreamGetFlags failed: ",
181+
cudaGetErrorString(err));
182+
183+
NppiSize oSizeROI = {frameDims.width, frameDims.height};
184+
Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};
185+
186+
NppStatus status;
187+
188+
// For background, see
189+
// Note [YUV -> RGB Color Conversion, color space and color range]
190+
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
191+
if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) {
192+
// NPP provides a pre-defined color conversion function for BT.709 full
193+
// range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But it's not closely
194+
// matching the results we have on CPU. So we're using a custom color
195+
// conversion matrix, which provides more accurate results. See the note
196+
// mentioned above for details, and headaches.
197+
198+
int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]};
199+
200+
status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx(
201+
yuvData,
202+
srcStep,
203+
static_cast<Npp8u*>(dst.data_ptr()),
204+
dst.stride(0),
205+
oSizeROI,
206+
bt709FullRangeColorTwist,
207+
*nppCtx);
208+
} else {
209+
// If not full range, we assume studio limited range.
210+
// The color conversion matrix for BT.709 limited range should be:
211+
// static const Npp32f bt709LimitedRangeColorTwist[3][4] = {
212+
// {1.16438356f, 0.0f, 1.79274107f, -16.0f},
213+
// {1.16438356f, -0.213248614f, -0.5329093290f, -128.0f},
214+
// {1.16438356f, 2.11240179f, 0.0f, -128.0f}
215+
// };
216+
// We get very close results to CPU with that, but using the pre-defined
217+
// nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx seems to be even more accurate.
218+
status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
219+
yuvData,
220+
avFrame->linesize[0],
221+
static_cast<Npp8u*>(dst.data_ptr()),
222+
dst.stride(0),
223+
oSizeROI,
224+
*nppCtx);
225+
}
226+
} else {
227+
// TODO we're assuming BT.601 color space (and probably limited range) by
228+
// calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range,
229+
// and other color-spaces like 2020.
230+
status = nppiNV12ToRGB_8u_P2C3R_Ctx(
231+
yuvData,
232+
avFrame->linesize[0],
233+
static_cast<Npp8u*>(dst.data_ptr()),
234+
dst.stride(0),
235+
oSizeROI,
236+
*nppCtx);
237+
}
238+
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
239+
240+
return dst;
241+
}
242+
243+
UniqueNppContext getNppStreamContext(const torch::Device& device) {
244+
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
245+
246+
UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device);
247+
if (nppCtx) {
248+
return nppCtx;
249+
}
250+
251+
// From 12.9, NPP recommends using a user-created NppStreamContext and using
252+
// the `_Ctx()` calls:
253+
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
254+
// And the nppGetStreamContext() helper is deprecated. We are explicitly
255+
// supposed to create the NppStreamContext manually from the CUDA device
256+
// properties:
257+
// https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72
258+
259+
nppCtx = std::make_unique<NppStreamContext>();
260+
cudaDeviceProp prop{};
261+
cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
262+
TORCH_CHECK(
263+
err == cudaSuccess,
264+
"cudaGetDeviceProperties failed: ",
265+
cudaGetErrorString(err));
266+
267+
nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
268+
nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
269+
nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
270+
nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
271+
nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock;
272+
nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major;
273+
nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor;
274+
275+
return nppCtx;
276+
}
277+
278+
void returnNppStreamContextToCache(
279+
const torch::Device& device,
280+
UniqueNppContext nppCtx) {
281+
if (nppCtx) {
282+
g_cached_npp_ctxs.addIfCacheHasCapacity(device, std::move(nppCtx));
283+
}
284+
}
285+
286+
void validatePreAllocatedTensorShape(
287+
const std::optional<torch::Tensor>& preAllocatedOutputTensor,
288+
const UniqueAVFrame& avFrame) {
289+
// Note that CUDA does not yet support transforms, so the only possible
290+
// frame dimensions are the raw decoded frame's dimensions.
291+
auto frameDims = FrameDims(avFrame->height, avFrame->width);
292+
293+
if (preAllocatedOutputTensor.has_value()) {
294+
auto shape = preAllocatedOutputTensor.value().sizes();
295+
TORCH_CHECK(
296+
(shape.size() == 3) && (shape[0] == frameDims.height) &&
297+
(shape[1] == frameDims.width) && (shape[2] == 3),
298+
"Expected tensor of shape ",
299+
frameDims.height,
300+
"x",
301+
frameDims.width,
302+
"x3, got ",
303+
shape);
304+
}
305+
}
306+
307+
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)