Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 90 additions & 19 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Tensor Parallel Initialize Distributed Environment
==================================================

This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities are useful for tensor parallel distributed inference examples using torch.distributed.
"""

import logging
Expand All @@ -16,30 +16,66 @@
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh

logger = logging.getLogger(__name__)

def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)

raise RuntimeError("Could not find repo root")
def initialize_logger(
rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO
):
"""Initialize rank-specific Torch-TensorRT logger with configurable handler levels.

Logger level is set to DEBUG (pass-through), handlers control filtering for files and stream buffers

def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
Args:
rank: Process rank for multi-GPU
logger_file_name: Base name for log file (will add _rank.log)
file_level: What goes to file - default DEBUG (everything)
console_level: What prints to console - default INFO (clean output)
"""
logger = logging.getLogger("torch_tensorrt")
logger.setLevel(logging.DEBUG)
logger.handlers.clear()

# File handler
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
fh.setLevel(file_level)
fh.setFormatter(
logging.Formatter(
f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
)
logger.addHandler(fh)

# console handler
ch = logging.StreamHandler()
ch.setLevel(console_level) # Console handler controls what's printed
ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
logger.addHandler(ch)

# safegauard though not reqd
logger.propagate = False
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
def initialize_distributed_env(
logger_file_name,
rank=0,
world_size=1,
port=29500,
file_level="debug",
console_level="info",
):
"""Initialize distributed environment with handler-based logging.

Args:
logger_file_name: Base name for log files
rank: Initial rank (overridden by OMPI env vars)
world_size: Initial world size (overridden by OMPI env vars)
port: Master port for distributed communication
file_level: File handler level - "debug", "info", "warning" (default: "debug")
console_level: Console handler level - "debug", "info", "warning" (default: "info")
"""
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
Expand All @@ -50,9 +86,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
Expand All @@ -66,12 +99,50 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)

# Convert string handler levels to logging constants
level_map = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
}
file_level_int = level_map.get(file_level.lower(), logging.DEBUG)
console_level_int = level_map.get(console_level.lower(), logging.INFO)

# Initialize logger with handler-specific levels
# Logger itself is always DEBUG - handlers do the filtering
logger = initialize_logger(
rank,
logger_file_name,
file_level=file_level_int,
console_level=console_level_int,
)

device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

# Set C++ TensorRT runtime log level based on most verbose handler
# this is similar to set_log_level()
cpp_level = min(file_level_int, console_level_int)
try:
import tensorrt as trt
from torch_tensorrt._features import ENABLED_FEATURES

if ENABLED_FEATURES.torch_tensorrt_runtime:
if cpp_level == logging.DEBUG:
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE))
elif cpp_level == logging.INFO:
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO))
elif cpp_level == logging.WARNING:
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING))
else:
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR))
except Exception as e:
logger.warning(f"Could not set C++ TensorRT log level: {e}")

return device_mesh, world_size, rank, logger


Expand Down
23 changes: 15 additions & 8 deletions py/torch_tensorrt/_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@

def _enabled_features_str() -> str:
enabled = lambda x: "ENABLED" if x else "DISABLED"
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n" # type: ignore[no-untyped-call]
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call]
return out_str


Expand Down Expand Up @@ -163,17 +163,24 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:


def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Runtime check decorator for TensorRT-LLM NCCL plugin availability.

WARNING: This decorator CANNOT prevent registration of converters at import time.
When used with @dynamo_tensorrt_converter, the converter is always registered
regardless of decorator order, because registration happens at import time before
the wrapper is called.

This decorator is kept for potential non-registration use cases where
runtime checks are appropriate.
@apbose: to discuss if this is required
"""

def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.trtllm_for_nccl:
return f(*args, **kwargs)
else:

def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
raise NotImplementedError(
"Refit feature is currently not available in Python 3.13 or higher"
)

return not_implemented(*args, **kwargs)
raise NotImplementedError("TensorRT-LLM plugin for NCCL is not available")

return wrapper

Expand Down
78 changes: 47 additions & 31 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt._features import needs_trtllm_for_nccl
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
Expand All @@ -20,37 +20,53 @@
_LOGGER: logging.Logger = logging.getLogger(__name__)


@needs_trtllm_for_nccl
@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_gather(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
# Conditionally register NCCL converters only if TensorRT-LLM plugin is available.
# We use an `if` statement instead of @needs_trtllm_for_nccl decorator because
# @dynamo_tensorrt_converter ALWAYS registers at import time regardless of decorator
# order. Conditional registration prevents registration when TRTLLM is unavailable,
# allowing fallback to PyTorch execution for NCCL ops.

# Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to plugin registry not finding nccl ops plugins since we register the bare converter, without the decorator
# Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available :TensorRT-LLM plugin for NCCL is not available" and no fall back to pytorch
if ENABLED_FEATURES.trtllm_for_nccl:
_LOGGER.debug(
"TensorRT-LLM plugin for NCCL is available. Registering NCCL converters."
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_gather(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
def fused_nccl_reduce_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@needs_trtllm_for_nccl
@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
def fused_nccl_reduce_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
else:
_LOGGER.info(
"TensorRT-LLM plugin for NCCL is not available. "
"NCCL operations will fall back to PyTorch execution."
)
25 changes: 22 additions & 3 deletions tests/py/dynamo/distributed/test_nccl_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,24 @@
from distributed_utils import set_environment_variables_pytest
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt._utils import is_platform_supported_for_trtllm


def is_distributed_nccl_available():
"""
Check if torch.distributed with NCCL backend is available.

Note: torch.distributed is available on Windows but NCCL backend is not.
NCCL (NVIDIA Collective Communications Library) is Linux/Unix only.
This function returns False on Windows, Jetson, and other platforms
where NCCL backend is not supported.
"""
try:
import torch.distributed as dist

# Check if NCCL backend is available (False on Windows, since its gloo. For ORIN some torch distribution it is available
return dist.is_nccl_available()
except (ImportError, AttributeError):
return False


class DistributedGatherModel(nn.Module):
Expand Down Expand Up @@ -42,9 +59,11 @@ def forward(self, x):


class TestNcclOpsConverter(DispatchTestCase):
# 1. Skip if NCCL backend is not available (e.g., Windows, Jetson) - hard requirement
# 2. Don't skip if TRTLLM is unavailable (e.g., CUDA 13) - falls back to PyTorch
@unittest.skipIf(
not is_platform_supported_for_trtllm(),
"Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.",
not is_distributed_nccl_available(),
"Skipped: NCCL backend is not available (Windows/Jetson not supported).",
)
@classmethod
def setUpClass(cls):
Expand Down
Loading