From 31666e31d05dbb6f6b4e03db1378d5e1f65d1b81 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 13 Nov 2025 16:03:03 -0800 Subject: [PATCH] Adding rank based logging for torch distributed examples. Also correcting TRT-LLM installation fallback cases --- .../tensor_parallel_initialize_dist.py | 109 +++++++++++++++--- py/torch_tensorrt/_features.py | 23 ++-- .../conversion/custom_ops_converters.py | 78 ++++++++----- tests/py/dynamo/distributed/test_nccl_ops.py | 25 +++- 4 files changed, 174 insertions(+), 61 deletions(-) diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py index 98d3ca18e9..505009e8e8 100644 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -3,7 +3,7 @@ Tensor Parallel Initialize Distributed Environment ================================================== -This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. +This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities are useful for tensor parallel distributed inference examples using torch.distributed. """ import logging @@ -16,30 +16,66 @@ import torch.distributed as dist from torch.distributed._tensor.device_mesh import init_device_mesh +logger = logging.getLogger(__name__) -def find_repo_root(max_depth=10): - dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): - files = os.listdir(dir_path) - if "MODULE.bazel" in files: - return dir_path - else: - dir_path = os.path.dirname(dir_path) - raise RuntimeError("Could not find repo root") +def initialize_logger( + rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO +): + """Initialize rank-specific Torch-TensorRT logger with configurable handler levels. + Logger level is set to DEBUG (pass-through), handlers control filtering for files and stream buffers -def initialize_logger(rank, logger_file_name): - logger = logging.getLogger() - logger.setLevel(logging.INFO) + Args: + rank: Process rank for multi-GPU + logger_file_name: Base name for log file (will add _rank.log) + file_level: What goes to file - default DEBUG (everything) + console_level: What prints to console - default INFO (clean output) + """ + logger = logging.getLogger("torch_tensorrt") + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + + # File handler fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) + fh.setLevel(file_level) + fh.setFormatter( + logging.Formatter( + f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ) logger.addHandler(fh) + + # console handler + ch = logging.StreamHandler() + ch.setLevel(console_level) # Console handler controls what's printed + ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s")) + logger.addHandler(ch) + + # safegauard though not reqd + logger.propagate = False return logger # This is required for env initialization since we use mpirun -def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): +def initialize_distributed_env( + logger_file_name, + rank=0, + world_size=1, + port=29500, + file_level="debug", + console_level="info", +): + """Initialize distributed environment with handler-based logging. + + Args: + logger_file_name: Base name for log files + rank: Initial rank (overridden by OMPI env vars) + world_size: Initial world size (overridden by OMPI env vars) + port: Master port for distributed communication + file_level: File handler level - "debug", "info", "warning" (default: "debug") + console_level: Console handler level - "debug", "info", "warning" (default: "info") + """ local_rank = int( os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) ) @@ -50,9 +86,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["TRTLLM_PLUGINS_PATH"] = ( - find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so" - ) # Necessary to assign a device to each rank. torch.cuda.set_device(local_rank) @@ -66,12 +99,50 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) rank = device_mesh.get_rank() assert rank == local_rank - logger = initialize_logger(rank, logger_file_name) + + # Convert string handler levels to logging constants + level_map = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + } + file_level_int = level_map.get(file_level.lower(), logging.DEBUG) + console_level_int = level_map.get(console_level.lower(), logging.INFO) + + # Initialize logger with handler-specific levels + # Logger itself is always DEBUG - handlers do the filtering + logger = initialize_logger( + rank, + logger_file_name, + file_level=file_level_int, + console_level=console_level_int, + ) + device_id = ( rank % torch.cuda.device_count() ) # Ensure each rank gets a unique device torch.cuda.set_device(device_id) + # Set C++ TensorRT runtime log level based on most verbose handler + # this is similar to set_log_level() + cpp_level = min(file_level_int, console_level_int) + try: + import tensorrt as trt + from torch_tensorrt._features import ENABLED_FEATURES + + if ENABLED_FEATURES.torch_tensorrt_runtime: + if cpp_level == logging.DEBUG: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE)) + elif cpp_level == logging.INFO: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO)) + elif cpp_level == logging.WARNING: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING)) + else: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR)) + except Exception as e: + logger.warning(f"Could not set C++ TensorRT log level: {e}") + return device_mesh, world_size, rank, logger diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 03cf4256ec..0fe3014855 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -74,7 +74,7 @@ def _enabled_features_str() -> str: enabled = lambda x: "ENABLED" if x else "DISABLED" - out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n" # type: ignore[no-untyped-call] + out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call] return out_str @@ -163,17 +163,24 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]: + """ + Runtime check decorator for TensorRT-LLM NCCL plugin availability. + + WARNING: This decorator CANNOT prevent registration of converters at import time. + When used with @dynamo_tensorrt_converter, the converter is always registered + regardless of decorator order, because registration happens at import time before + the wrapper is called. + + This decorator is kept for potential non-registration use cases where + runtime checks are appropriate. + @apbose: to discuss if this is required + """ + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: if ENABLED_FEATURES.trtllm_for_nccl: return f(*args, **kwargs) else: - - def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: - raise NotImplementedError( - "Refit feature is currently not available in Python 3.13 or higher" - ) - - return not_implemented(*args, **kwargs) + raise NotImplementedError("TensorRT-LLM plugin for NCCL is not available") return wrapper diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index db14e3528b..302a254f60 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -5,7 +5,7 @@ import tensorrt as trt from torch.fx.node import Argument, Target -from torch_tensorrt._features import needs_trtllm_for_nccl +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -20,37 +20,53 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -@needs_trtllm_for_nccl -@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) -def fused_nccl_gather( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_gather( - ctx, - target, - SourceIR.ATEN, - name, - [args[0]], +# Conditionally register NCCL converters only if TensorRT-LLM plugin is available. +# We use an `if` statement instead of @needs_trtllm_for_nccl decorator because +# @dynamo_tensorrt_converter ALWAYS registers at import time regardless of decorator +# order. Conditional registration prevents registration when TRTLLM is unavailable, +# allowing fallback to PyTorch execution for NCCL ops. + +# Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to plugin registry not finding nccl ops plugins since we register the bare converter, without the decorator +# Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available :TensorRT-LLM plugin for NCCL is not available" and no fall back to pytorch +if ENABLED_FEATURES.trtllm_for_nccl: + _LOGGER.debug( + "TensorRT-LLM plugin for NCCL is available. Registering NCCL converters." ) + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) + def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_gather( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) + def fused_nccl_reduce_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_reduce_scatter( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) -@needs_trtllm_for_nccl -@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) -def fused_nccl_reduce_scatter( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_reduce_scatter( - ctx, - target, - SourceIR.ATEN, - name, - [args[0]], +else: + _LOGGER.info( + "TensorRT-LLM plugin for NCCL is not available. " + "NCCL operations will fall back to PyTorch execution." ) diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index eafe16d455..5058bb24d0 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -8,7 +8,24 @@ from distributed_utils import set_environment_variables_pytest from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt._utils import is_platform_supported_for_trtllm + + +def is_distributed_nccl_available(): + """ + Check if torch.distributed with NCCL backend is available. + + Note: torch.distributed is available on Windows but NCCL backend is not. + NCCL (NVIDIA Collective Communications Library) is Linux/Unix only. + This function returns False on Windows, Jetson, and other platforms + where NCCL backend is not supported. + """ + try: + import torch.distributed as dist + + # Check if NCCL backend is available (False on Windows, since its gloo. For ORIN some torch distribution it is available + return dist.is_nccl_available() + except (ImportError, AttributeError): + return False class DistributedGatherModel(nn.Module): @@ -42,9 +59,11 @@ def forward(self, x): class TestNcclOpsConverter(DispatchTestCase): + # 1. Skip if NCCL backend is not available (e.g., Windows, Jetson) - hard requirement + # 2. Don't skip if TRTLLM is unavailable (e.g., CUDA 13) - falls back to PyTorch @unittest.skipIf( - not is_platform_supported_for_trtllm(), - "Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.", + not is_distributed_nccl_available(), + "Skipped: NCCL backend is not available (Windows/Jetson not supported).", ) @classmethod def setUpClass(cls):