Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
path = packages/on_demand_video_decoder/ext_impl/external/NVTX
url = https://github.com/NVIDIA/NVTX.git

[submodule "packages/optim_test_tools/ext_impl/external/NVTX"]
path = packages/optim_test_tools/ext_impl/external/NVTX
url = https://github.com/NVIDIA/NVTX.git

[submodule "packages/on_demand_video_decoder/ext_impl/external/dlpack"]
path = packages/on_demand_video_decoder/ext_impl/external/dlpack
url = https://github.com/dmlc/dlpack.git
Expand Down
10 changes: 9 additions & 1 deletion docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

MAKEFILE_DIR := $(patsubst %/,%,$(abspath $(dir $(lastword $(MAKEFILE_LIST)))))
CURRENT_DIR := $(patsubst %/,%,$(abspath $(CURDIR)))

ifneq ($(CURRENT_DIR),$(MAKEFILE_DIR))
$(error Please run make from $(MAKEFILE_DIR). Example: 'cd $(MAKEFILE_DIR) && make clean html' or 'make -C $(MAKEFILE_DIR) clean html')
endif

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
Expand All @@ -29,9 +36,10 @@ sync-readme:
html: sync-readme generate
@$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# Clean build directory and generated files
# Clean build directory and generated files (full removal to avoid stale sidebar/toctree)
clean:
@$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
rm -rf $(BUILDDIR)/
rm -rf api/generated/
rm -rf ../packages/*/docs/generated/

Expand Down
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,6 @@ literalinclude
blockquote
distributable
posix
JIT
prepend
prepended
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def backward(ctx: Any, grad: Union[torch.Tensor, None]):
if grad is None:
return None, None, None, None, None
else:
(output_indices, output_nums_indices) = ctx.saved_tensors
output_indices, output_nums_indices = ctx.saved_tensors
grad = grad.contiguous()
grad_input = batched_indexing_access_cuda.forward(grad, output_indices, output_nums_indices, 0.0)
return grad_input, None, None, None, None
Expand Down Expand Up @@ -154,7 +154,7 @@ def backward(ctx: Any, grad: Union[torch.Tensor, None]):
if grad is None:
return None, None, None, None
else:
(output_indices, output_nums_indices) = ctx.saved_tensors
output_indices, output_nums_indices = ctx.saved_tensors
grad = grad.contiguous()
grad_for_to_insert = batched_indexing_access_cuda.forward(
grad, output_indices, output_nums_indices, 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,34 +254,34 @@ static void SavePacketBufferToFile(const uint8_t* packet_buffer, int nVideoBytes
* Reference: ITU-T H.264 Table 7-1
*/
enum H264NalUnitType {
H264_NAL_SLICE = 1, // Coded slice of a non-IDR picture
H264_NAL_DPA = 2, // Coded slice data partition A
H264_NAL_DPB = 3, // Coded slice data partition B
H264_NAL_DPC = 4, // Coded slice data partition C
H264_NAL_IDR_SLICE = 5, // Coded slice of an IDR picture
H264_NAL_SEI = 6, // Supplemental enhancement information
H264_NAL_SPS = 7, // Sequence parameter set
H264_NAL_PPS = 8, // Picture parameter set
H264_NAL_AUD = 9, // Access unit delimiter
H264_NAL_END_SEQUENCE = 10, // End of sequence
H264_NAL_END_STREAM = 11, // End of stream
H264_NAL_FILLER_DATA = 12, // Filler data
H264_NAL_SLICE = 1, // Coded slice of a non-IDR picture
H264_NAL_DPA = 2, // Coded slice data partition A
H264_NAL_DPB = 3, // Coded slice data partition B
H264_NAL_DPC = 4, // Coded slice data partition C
H264_NAL_IDR_SLICE = 5, // Coded slice of an IDR picture
H264_NAL_SEI = 6, // Supplemental enhancement information
H264_NAL_SPS = 7, // Sequence parameter set
H264_NAL_PPS = 8, // Picture parameter set
H264_NAL_AUD = 9, // Access unit delimiter
H264_NAL_END_SEQUENCE = 10, // End of sequence
H264_NAL_END_STREAM = 11, // End of stream
H264_NAL_FILLER_DATA = 12, // Filler data
};

/**
* @brief HEVC/H.265 NAL unit type enumeration
* Reference: ITU-T H.265 Table 7-1
*/
enum HevcNalUnitType {
HEVC_NAL_IDR_W_RADL = 19, // IDR picture with RADL pictures
HEVC_NAL_IDR_N_LP = 20, // IDR picture without leading pictures
HEVC_NAL_CRA_NUT = 21, // Clean random access picture
HEVC_NAL_VPS = 32, // Video parameter set
HEVC_NAL_SPS = 33, // Sequence parameter set
HEVC_NAL_PPS = 34, // Picture parameter set
HEVC_NAL_AUD = 35, // Access unit delimiter
HEVC_NAL_PREFIX_SEI = 39, // Prefix SEI message
HEVC_NAL_SUFFIX_SEI = 40, // Suffix SEI message
HEVC_NAL_IDR_W_RADL = 19, // IDR picture with RADL pictures
HEVC_NAL_IDR_N_LP = 20, // IDR picture without leading pictures
HEVC_NAL_CRA_NUT = 21, // Clean random access picture
HEVC_NAL_VPS = 32, // Video parameter set
HEVC_NAL_SPS = 33, // Sequence parameter set
HEVC_NAL_PPS = 34, // Picture parameter set
HEVC_NAL_AUD = 35, // Access unit delimiter
HEVC_NAL_PREFIX_SEI = 39, // Prefix SEI message
HEVC_NAL_SUFFIX_SEI = 40, // Suffix SEI message
};

/**
Expand All @@ -290,15 +290,15 @@ enum HevcNalUnitType {
* Reference: AV1 Bitstream & Decoding Process Specification
*/
enum AV1ObuType {
OBU_SEQUENCE_HEADER = 1, // Sequence header, appears at key frames
OBU_TEMPORAL_DELIMITER = 2, // Temporal delimiter
OBU_FRAME_HEADER = 3, // Frame header
OBU_TILE_GROUP = 4, // Tile group
OBU_METADATA = 5, // Metadata
OBU_FRAME = 6, // Frame (combined frame header and tile group)
OBU_REDUNDANT_FRAME_HEADER = 7, // Redundant frame header
OBU_TILE_LIST = 8, // Tile list
OBU_PADDING = 15, // Padding
OBU_SEQUENCE_HEADER = 1, // Sequence header, appears at key frames
OBU_TEMPORAL_DELIMITER = 2, // Temporal delimiter
OBU_FRAME_HEADER = 3, // Frame header
OBU_TILE_GROUP = 4, // Tile group
OBU_METADATA = 5, // Metadata
OBU_FRAME = 6, // Frame (combined frame header and tile group)
OBU_REDUNDANT_FRAME_HEADER = 7, // Redundant frame header
OBU_TILE_LIST = 8, // Tile list
OBU_PADDING = 15, // Padding
};

/**
Expand All @@ -318,18 +318,17 @@ inline bool iskeyFrame(AVCodecID codec_id, const uint8_t* pVideo, int demux_flag
uint8_t b = pVideo[2] == 1 ? pVideo[3] : pVideo[4];
int nal_unit_type = b >> 1;
// Check for VPS, SPS, PPS, or SEI NAL units which indicate key frame start
if (nal_unit_type == HEVC_NAL_VPS || nal_unit_type == HEVC_NAL_SPS ||
nal_unit_type == HEVC_NAL_PPS || nal_unit_type == HEVC_NAL_PREFIX_SEI ||
nal_unit_type == HEVC_NAL_SUFFIX_SEI) {
if (nal_unit_type == HEVC_NAL_VPS || nal_unit_type == HEVC_NAL_SPS || nal_unit_type == HEVC_NAL_PPS ||
nal_unit_type == HEVC_NAL_PREFIX_SEI || nal_unit_type == HEVC_NAL_SUFFIX_SEI) {
bPS = true;
}
} else if (codec_id == AV_CODEC_ID_H264) {
uint8_t b = pVideo[2] == 1 ? pVideo[3] : pVideo[4];
int nal_ref_idc = b >> 5;
int nal_unit_type = b & 0x1f;
// Check for SEI, SPS, PPS, or AUD NAL units which indicate key frame start
if (nal_unit_type == H264_NAL_SEI || nal_unit_type == H264_NAL_SPS ||
nal_unit_type == H264_NAL_PPS || nal_unit_type == H264_NAL_AUD) {
if (nal_unit_type == H264_NAL_SEI || nal_unit_type == H264_NAL_SPS || nal_unit_type == H264_NAL_PPS ||
nal_unit_type == H264_NAL_AUD) {
bPS = true;
}
} else if (codec_id == AV_CODEC_ID_AV1) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .nvtx import register_string, range_push, range_pop

__all__ = [
"register_string",
"range_push",
"range_pop",
]
141 changes: 141 additions & 0 deletions packages/optim_test_tools/accvlab/optim_test_tools/numba_nvtx/nvtx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import ctypes

from . import _nvtx_numba_ext as _ext # type: ignore[attr-defined]


_SYMBOLS_READY = False


def _try_register_numba_symbols() -> bool:
"""
Register the extension's C symbols with llvmlite so that Numba ``@njit``
functions can call them. Returns ``False`` if Numba/llvmlite are not
installed (they are optional dependencies).
"""
try:
import llvmlite.binding as llvm
except ImportError:
return False

lib = ctypes.CDLL(_ext.__file__)
push = lib.accvlab_nvtx_range_push
pop = lib.accvlab_nvtx_range_pop
llvm.add_symbol("accvlab_nvtx_range_push", ctypes.cast(push, ctypes.c_void_p).value)
llvm.add_symbol("accvlab_nvtx_range_pop", ctypes.cast(pop, ctypes.c_void_p).value)
return True


_SYMBOLS_READY = _try_register_numba_symbols()


def register_string(name: str) -> int:
"""
Register a string with NVTX once and return an integer handle.

Returns 0 if profiler is not attached (the handle is still safe to pass to
:func:`range_push`, which treats 0 as a no-op).
"""
return int(_ext.register_string(name))


def range_push(handle: int) -> None:
"""
Push an NVTX range using a previously-registered handle.

This function can be called from within Numba ``@njit`` functions.
"""
_ext.range_push(int(handle))


def range_pop() -> None:
"""
Pop an NVTX range.

This function can be called from within Numba ``@njit`` functions.
"""
_ext.range_pop()


# ---------------------- Numba lowering (CPU @njit) ----------------------

try:
from llvmlite import ir
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload
except ImportError:
pass
else:

@intrinsic
def _range_push_intrin(typingctx, handle):
sig = types.void(handle)

def codegen(context, builder, signature, args):
i64 = ir.IntType(64)
fnty = ir.FunctionType(ir.VoidType(), [i64])
fn = cgutils.get_or_insert_function(builder.module, fnty, "accvlab_nvtx_range_push")
arg0 = args[0]
if arg0.type != i64:
arg0 = builder.sext(arg0, i64) if arg0.type.width < 64 else builder.trunc(arg0, i64)
builder.call(fn, [arg0])
return context.get_dummy_value()

return sig, codegen

@intrinsic
def _range_pop_intrin(typingctx):
sig = types.void()

def codegen(context, builder, signature, args):
fnty = ir.FunctionType(ir.VoidType(), [])
fn = cgutils.get_or_insert_function(builder.module, fnty, "accvlab_nvtx_range_pop")
builder.call(fn, [])
return context.get_dummy_value()

return sig, codegen

@overload(range_push, inline="always")
def _ov_range_push(handle):
if isinstance(handle, types.Integer):

if not _SYMBOLS_READY:
raise TypingError(
"NVTX C symbols were not registered with llvmlite. "
"This is unexpected — the extension is present but symbol binding failed at import time."
)

def impl(handle):
_range_push_intrin(handle)

return impl
return None

@overload(range_pop, inline="always")
def _ov_range_pop():
if not _SYMBOLS_READY:
raise TypingError(
"NVTX C symbols were not registered with llvmlite. "
"This is unexpected — the extension is present but symbol binding failed at import time."
)

def impl():
_range_pop_intrin()

return impl
Loading