From 63c3e63ba7d3ab3ac540d817d2b59a53e36e2b4c Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 13 Feb 2025 12:10:45 -0800 Subject: [PATCH 01/10] TensorRT-LLM import fix and aot_joint_export specify as explicit setting in dynamo.compile TRT-LLM installation utilities and adding test cases adding the option in _compiler.py changes in the TRT-LLM loading tool- removing install_wget, install_unzip, install_mpi Further changes in error logging of the TRT-LLM installation tool moving the load_tensorrt_llm to dynamo/utils.py correcting misprint for TRT LLM load Using python lib for download to make it platform agnostic dll file path update for windows correcting the non critical lint error Including version in versions.txt --- dev_dep_versions.yml | 1 + py/torch_tensorrt/dynamo/_compiler.py | 12 ++ .../dynamo/conversion/converter_utils.py | 67 +-------- .../conversion/custom_ops_converters.py | 2 +- py/torch_tensorrt/dynamo/utils.py | 130 +++++++++++++++++- setup.py | 4 + 6 files changed, 149 insertions(+), 67 deletions(-) diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index c9a738feb6..c57a2d8d9e 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,2 +1,3 @@ __cuda_version__: "12.8" __tensorrt_version__: "10.12.0" +__tensorrt_llm_version__: "0.17.0.post1" diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index ff7d3b7a07..9ed361082d 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -103,6 +103,7 @@ def cross_compile_for_windows( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -177,6 +178,7 @@ def cross_compile_for_windows( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -339,6 +341,7 @@ def cross_compile_for_windows( "enable_weight_streaming": enable_weight_streaming, "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, + "use_distributed_mode_trace": use_distributed_mode_trace, } # disable the following settings is not supported for cross compilation for windows feature @@ -439,6 +442,7 @@ def compile( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -515,7 +519,11 @@ def compile( enable_weight_streaming (bool): Enable weight streaming. 
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). +<<<<<<< HEAD offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. +======= + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model +>>>>>>> c3b62d239 (TensorRT-LLM import fix and aot_joint_export specify as explicit setting in dynamo.compile) **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -688,6 +696,7 @@ def compile( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) @@ -1052,6 +1061,7 @@ def convert_exported_program_to_serialized_trt_engine( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -1116,6 +1126,7 @@ def convert_exported_program_to_serialized_trt_engine( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). 
+ use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ @@ -1238,6 +1249,7 @@ def convert_exported_program_to_serialized_trt_engine( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 896bf37b42..c988ea9759 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,8 +1,6 @@ import collections -import ctypes import functools import logging -import os from typing import ( Any, Callable, @@ -25,6 +23,7 @@ import torch_tensorrt.dynamo.conversion.impl as impl from torch_tensorrt import _enums +from torch_tensorrt._enums import Platform from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -1117,69 +1116,6 @@ def args_bounds_check( return args[i] if len(args) > i and args[i] is not None else replacement -def load_tensorrt_llm() -> bool: - """ - Attempts to load the TensorRT-LLM plugin and initialize it. - - Returns: - bool: True if the plugin was successfully loaded and initialized, False otherwise. - """ - try: - import tensorrt_llm as trt_llm # noqa: F401 - - _LOGGER.info("TensorRT-LLM successfully imported") - return True - except (ImportError, AssertionError) as e_import_error: - # Check for environment variable for the plugin library path - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") - if not plugin_lib_path: - _LOGGER.warning( - "TensorRT-LLM is not installed. 
Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops", - ) - return False - - _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}") - try: - # Load the shared library - handle = ctypes.CDLL(plugin_lib_path) - _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") - except OSError as e_os_error: - _LOGGER.error( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" - f"Ensure the path is correct and the library is compatible", - exc_info=e_os_error, - ) - return False - - try: - # Configure plugin initialization arguments - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] - handle.initTrtLlmPlugins.restype = ctypes.c_bool - except AttributeError as e_plugin_unavailable: - _LOGGER.warning( - "Unable to initialize the TensorRT-LLM plugin library", - exc_info=e_plugin_unavailable, - ) - return False - - try: - # Initialize the plugin - TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" - if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): - _LOGGER.info("TensorRT-LLM plugin successfully initialized") - return True - else: - _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") - return False - except Exception as e_initialization_error: - _LOGGER.warning( - "Exception occurred during TensorRT-LLM plugin library initialization", - exc_info=e_initialization_error, - ) - return False - return False - - def promote_trt_tensors_to_same_dtype( ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str ) -> tuple[TRTTensor, TRTTensor]: @@ -1217,3 +1153,4 @@ def promote_trt_tensors_to_same_dtype( rhs_cast = cast_trt_tensor(ctx, rhs, promoted_dtype, f"{name_prefix}rhs_cast") return lhs_cast, rhs_cast + diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 79611c7552..3e67457e54 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -11,11 +11,11 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, ) -from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_reduce_scatter_op, ) +from torch_tensorrt.dynamo.utils import load_tensorrt_llm _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 0703fd1cb9..55a99c67bf 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,7 +1,10 @@ from __future__ import annotations +import ctypes import gc import logging +import os +import urllib.request import warnings from dataclasses import fields, replace from enum import Enum @@ -14,9 +17,10 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device -from torch_tensorrt._enums import dtype +from torch_tensorrt._enums import Platform, dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input +from torch_tensorrt._version import __tensorrt_llm_version__ from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._defaults import default_device 
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache @@ -820,3 +824,127 @@ def is_tegra_platform() -> bool: if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]: return True return False + + +def download_plugin_lib_path(py_version: str, platform: str) -> str: + plugin_lib_path = None + + # Downloading TRT-LLM lib + base_url = "https://pypi.nvidia.com/tensorrt-llm/" + file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl" + download_url = base_url + file_name + if not (os.path.exists(file_name)): + try: + logger.debug(f"Downloading {download_url} ...") + urllib.request.urlretrieve(download_url, file_name) + logger.debug("Download succeeded and TRT-LLM wheel is now present") + except urllib.error.HTTPError as e: + logger.error( + f"HTTP error {e.code} when trying to download {download_url}: {e.reason}" + ) + except urllib.error.URLError as e: + logger.error( + f"URL error when trying to download {download_url}: {e.reason}" + ) + except OSError as e: + logger.error(f"Local file write error: {e}") + + # Proceeding with the unzip of the wheel file + # This will exist if the filename was already downloaded + if "linux" in platform: + lib_filename = "libnvinfer_plugin_tensorrt_llm.so" + else: + lib_filename = "libnvinfer_plugin_tensorrt_llm.dll" + plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename) + if os.path.exists(plugin_lib_path): + return plugin_lib_path + try: + import zipfile + except ImportError as e: + raise ImportError( + "zipfile module is required but not found. Please install zipfile" + ) + with zipfile.ZipFile(file_name, "r") as zip_ref: + zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm' + plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename + return plugin_lib_path + + +def load_tensorrt_llm() -> bool: + """ + Attempts to load the TensorRT-LLM plugin and initialize it. + Either the env variable TRTLLM_PLUGINS_PATH can specify the path + Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it + + Returns: + bool: True if the plugin was successfully loaded and initialized, False otherwise. + """ + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + if not plugin_lib_path: + # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user + use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + if not use_trtllm_plugin: + logger.warning( + "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" + ) + return False + else: + # this is used as the default py version + py_version = "cp310" + platform = Platform.current_platform() + + platform = str(platform).lower() + plugin_lib_path = download_plugin_lib_path(py_version, platform) + + try: + # Load the shared TRT-LLM file + handle = ctypes.CDLL(plugin_lib_path) + logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") + except OSError as e_os_error: + if "libmpi" in str(e_os_error): + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " + f"The dependency libmpi.so is missing. 
" + f"Please install the packages libmpich-dev and libopenmpi-dev.", + exc_info=e_os_error, + ) + else: + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" + f"Ensure the path is correct and the library is compatible", + exc_info=e_os_error, + ) + return False + + try: + # Configure plugin initialization arguments + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + logger.warning( + "Unable to initialize the TensorRT-LLM plugin library", + exc_info=e_plugin_unavailable, + ) + return False + + try: + # Initialize the plugin + TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" + if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): + logger.info("TensorRT-LLM plugin successfully initialized") + return True + else: + logger.warning("TensorRT-LLM plugin library failed in initialization") + return False + except Exception as e_initialization_error: + logger.warning( + "Exception occurred during TensorRT-LLM plugin library initialization", + exc_info=e_initialization_error, + ) + return False + return False diff --git a/setup.py b/setup.py index f829602f1a..fb70b70a50 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ __version__: str = "0.0.0" __cuda_version__: str = "0.0" __tensorrt_version__: str = "0.0" +__tensorrt_llm_version__: str = "0.0" LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$") @@ -63,6 +64,7 @@ def get_base_version() -> str: def load_dep_info(): global __cuda_version__ global __tensorrt_version__ + global __tensorrt_llm_version__ with open("dev_dep_versions.yml", "r") as stream: versions = yaml.safe_load(stream) if (gpu_arch_version := os.environ.get("CU_VERSION")) is not None: @@ -72,6 +74,7 @@ def load_dep_info(): else: __cuda_version__ = versions["__cuda_version__"] __tensorrt_version__ = versions["__tensorrt_version__"] + __tensorrt_llm_version__ = versions["__tensorrt_llm_version__"] load_dep_info() @@ -240,6 +243,7 @@ def gen_version_file(): f.write('__version__ = "' + __version__ + '"\n') f.write('__cuda_version__ = "' + __cuda_version__ + '"\n') f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n') + f.write('__tensorrt_llm_version__ = "' + __tensorrt_llm_version__ + '"\n') def copy_libtorchtrt(multilinux=False, rt_only=False): From af987c22519a80ac869bbd749b144c5814f81e0c Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 20 May 2025 12:42:44 -0700 Subject: [PATCH 02/10] linting error fixes and rebase fix --- py/torch_tensorrt/dynamo/_compiler.py | 3 --- py/torch_tensorrt/dynamo/conversion/converter_utils.py | 1 - 2 files changed, 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9ed361082d..1d44b49874 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -519,11 +519,8 @@ def compile( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). -<<<<<<< HEAD offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. 
-======= use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model ->>>>>>> c3b62d239 (TensorRT-LLM import fix and aot_joint_export specify as explicit setting in dynamo.compile) **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index c988ea9759..35a64a5b54 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1153,4 +1153,3 @@ def promote_trt_tensors_to_same_dtype( rhs_cast = cast_trt_tensor(ctx, rhs, promoted_dtype, f"{name_prefix}rhs_cast") return lhs_cast, rhs_cast - From 3b2cab966213706c01f02372e8fa48deb4ae1fa7 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 20 May 2025 12:56:57 -0700 Subject: [PATCH 03/10] removing Platform enum from converter_utils.py --- py/torch_tensorrt/dynamo/conversion/converter_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 35a64a5b54..1ca1b33caf 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -23,7 +23,6 @@ import torch_tensorrt.dynamo.conversion.impl as impl from torch_tensorrt import _enums -from torch_tensorrt._enums import Platform from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext From 55917b9c0621b5fddfbf2d37f62a20f7bc045401 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 13 Jun 2025 14:51:45 -0700 Subject: [PATCH 04/10] Addressing review comments- tmp dir for wheel download and wheel extraction, variable for py_version --- py/torch_tensorrt/dynamo/utils.py | 194 +++++++++++++------ tests/py/dynamo/distributed/test_nccl_ops.py | 1 + 2 files changed, 136 insertions(+), 59 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 55a99c67bf..8860f5fdaf 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -2,13 +2,27 @@ import ctypes import gc +import getpass import logging import os +import tempfile import urllib.request import warnings +from contextlib import contextmanager from dataclasses import fields, replace from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import sympy @@ -37,6 +51,7 @@ RTOL = 5e-3 ATOL = 5e-3 CPU_DEVICE = "cpu" +_WHL_CPYTHON_VERSION = "cp310" class Frameworks(Enum): @@ -240,6 +255,19 @@ def set_log_level(parent_logger: Any, level: Any) -> None: """ if parent_logger: parent_logger.setLevel(level) + print("Handlers for parent_logger:", parent_logger.handlers) + print("bool check--", parent_logger.hasHandlers()) + if parent_logger.hasHandlers(): + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) # Allow debug messages on handler + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch.setFormatter(formatter) + parent_logger.addHandler(ch) + print("Logger level:", parent_logger.level) + # 
print("Parent logger level:", logger.parent.level) + print("Root logger level:", logging.getLogger().level) if ENABLED_FEATURES.torch_tensorrt_runtime: if level == logging.DEBUG: @@ -826,17 +854,41 @@ def is_tegra_platform() -> bool: return False -def download_plugin_lib_path(py_version: str, platform: str) -> str: - plugin_lib_path = None +@contextmanager +def download_plugin_lib_path(platform: str) -> Iterator[str]: + """ + Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform, + then yields the path to the extracted shared library (.so or .dll). - # Downloading TRT-LLM lib - base_url = "https://pypi.nvidia.com/tensorrt-llm/" - file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl" - download_url = base_url + file_name - if not (os.path.exists(file_name)): + The wheel file is cached in a user-specific temporary directory to avoid repeated downloads. + Extraction happens in a temporary directory that is cleaned up after use. + + Args: + platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel. + + Yields: + str: The full path to the extracted TensorRT-LLM shared library file. + + Raises: + ImportError: If the 'zipfile' module is not available. + RuntimeError: If the wheel file is missing, corrupted, or extraction fails. + """ + plugin_lib_path = None + username = getpass.getuser() + torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" + torchtrt_cache_dir.mkdir(parents=True, exist_ok=True) + file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl" + torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name + downloaded_file_path = torchtrt_cache_trtllm_whl + + if not torchtrt_cache_trtllm_whl.exists(): + # Downloading TRT-LLM lib + base_url = "https://pypi.nvidia.com/tensorrt-llm/" + download_url = base_url + file_name + print("Downloading TRT-LLM wheel") try: logger.debug(f"Downloading {download_url} ...") - urllib.request.urlretrieve(download_url, file_name) + urllib.request.urlretrieve(download_url, downloaded_file_path) logger.debug("Download succeeded and TRT-LLM wheel is now present") except urllib.error.HTTPError as e: logger.error( @@ -849,60 +901,53 @@ def download_plugin_lib_path(py_version: str, platform: str) -> str: except OSError as e: logger.error(f"Local file write error: {e}") - # Proceeding with the unzip of the wheel file - # This will exist if the filename was already downloaded + # Proceeding with the unzip of the wheel file in tmpdir if "linux" in platform: lib_filename = "libnvinfer_plugin_tensorrt_llm.so" else: lib_filename = "libnvinfer_plugin_tensorrt_llm.dll" - plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename) - if os.path.exists(plugin_lib_path): - return plugin_lib_path - try: - import zipfile - except ImportError as e: - raise ImportError( - "zipfile module is required but not found. Please install zipfile" - ) - with zipfile.ZipFile(file_name, "r") as zip_ref: - zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm' - plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename - return plugin_lib_path - -def load_tensorrt_llm() -> bool: + with tempfile.TemporaryDirectory() as tmpdir: + try: + import zipfile + except ImportError: + raise ImportError( + "zipfile module is required but not found. 
Please install zipfile" + ) + try: + with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref: + zip_ref.extractall(tmpdir) # Extract to a folder named 'tensorrt_llm' + except FileNotFoundError as e: + # This should capture the errors in the download failure above + logger.error(f"Wheel file not found at {downloaded_file_path}: {e}") + raise RuntimeError( + f"Failed to find downloaded wheel file at {downloaded_file_path}" + ) from e + except zipfile.BadZipFile as e: + logger.error(f"Invalid or corrupted wheel file: {e}") + raise RuntimeError( + "Downloaded wheel file is corrupted or not a valid zip archive" + ) from e + except Exception as e: + logger.error(f"Unexpected error while extracting wheel: {e}") + raise RuntimeError( + "Unexpected error during extraction of TensorRT-LLM wheel" + ) from e + plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename) + yield plugin_lib_path + + +def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: """ - Attempts to load the TensorRT-LLM plugin and initialize it. - Either the env variable TRTLLM_PLUGINS_PATH can specify the path - Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it + Loads and initializes the TensorRT-LLM plugin from the given shared library path. + + Args: + plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. Returns: - bool: True if the plugin was successfully loaded and initialized, False otherwise. + bool: True if successful, False otherwise. """ - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") - if not plugin_lib_path: - # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user - use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( - "1", - "true", - "yes", - "on", - ) - if not use_trtllm_plugin: - logger.warning( - "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" - ) - return False - else: - # this is used as the default py version - py_version = "cp310" - platform = Platform.current_platform() - - platform = str(platform).lower() - plugin_lib_path = download_plugin_lib_path(py_version, platform) - try: - # Load the shared TRT-LLM file handle = ctypes.CDLL(plugin_lib_path) logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") except OSError as e_os_error: @@ -915,14 +960,13 @@ def load_tensorrt_llm() -> bool: ) else: logger.warning( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" - f"Ensure the path is correct and the library is compatible", + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. 
" + f"Ensure the path is correct and the library is compatible.", exc_info=e_os_error, ) return False try: - # Configure plugin initialization arguments handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] handle.initTrtLlmPlugins.restype = ctypes.c_bool except AttributeError as e_plugin_unavailable: @@ -933,9 +977,7 @@ def load_tensorrt_llm() -> bool: return False try: - # Initialize the plugin - TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" - if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): + if handle.initTrtLlmPlugins(None, b"tensorrt_llm"): logger.info("TensorRT-LLM plugin successfully initialized") return True else: @@ -948,3 +990,37 @@ def load_tensorrt_llm() -> bool: ) return False return False + + +def load_tensorrt_llm() -> bool: + """ + Attempts to load the TensorRT-LLM plugin and initialize it. + Either the env variable TRTLLM_PLUGINS_PATH can specify the path + Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it + + Returns: + bool: True if the plugin was successfully loaded and initialized, False otherwise. + """ + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + if plugin_lib_path: + return load_and_initialize_trtllm_plugin(plugin_lib_path) + else: + # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user + use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + if not use_trtllm_plugin: + logger.warning( + "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" + ) + return False + else: + platform = Platform.current_platform() + platform = str(platform).lower() + + with download_plugin_lib_path(platform) as plugin_lib_path: + return load_and_initialize_trtllm_plugin(plugin_lib_path) + return False diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index 89c94300b7..a71fd1edc4 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -70,6 +70,7 @@ def forward(self, x): use_dynamo_tracer=True, enable_passes=True, ) + dist.destroy_process_group() if __name__ == "__main__": From b9493609826e2b7355c27b0106ad6aac7eb33134 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 1 Jul 2025 11:13:08 -0700 Subject: [PATCH 05/10] checks for windows where NCCL backend is not supported --- py/torch_tensorrt/dynamo/utils.py | 28 ++++++-------------- tests/py/dynamo/distributed/test_nccl_ops.py | 6 +++++ 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 8860f5fdaf..a6f2a90e53 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -255,19 +255,6 @@ def set_log_level(parent_logger: Any, level: Any) -> None: """ if parent_logger: parent_logger.setLevel(level) - print("Handlers for parent_logger:", parent_logger.handlers) - print("bool check--", parent_logger.hasHandlers()) - if parent_logger.hasHandlers(): - ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) # Allow debug messages on handler - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - ch.setFormatter(formatter) - parent_logger.addHandler(ch) - print("Logger level:", parent_logger.level) - # print("Parent logger level:", logger.parent.level) - print("Root logger 
level:", logging.getLogger().level) if ENABLED_FEATURES.torch_tensorrt_runtime: if level == logging.DEBUG: @@ -885,7 +872,6 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]: # Downloading TRT-LLM lib base_url = "https://pypi.nvidia.com/tensorrt-llm/" download_url = base_url + file_name - print("Downloading TRT-LLM wheel") try: logger.debug(f"Downloading {download_url} ...") urllib.request.urlretrieve(download_url, downloaded_file_path) @@ -937,7 +923,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]: yield plugin_lib_path -def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: +def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bool: """ Loads and initializes the TensorRT-LLM plugin from the given shared library path. @@ -947,6 +933,9 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: Returns: bool: True if successful, False otherwise. """ + if "windows" in platform: + logger.info("NCCL backend is not supported on Windows") + return False try: handle = ctypes.CDLL(plugin_lib_path) logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") @@ -1002,8 +991,10 @@ def load_tensorrt_llm() -> bool: bool: True if the plugin was successfully loaded and initialized, False otherwise. """ plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + platform = Platform.current_platform() + platform = str(platform).lower() if plugin_lib_path: - return load_and_initialize_trtllm_plugin(plugin_lib_path) + return load_and_initialize_trtllm_plugin(plugin_lib_path, platform) else: # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( @@ -1017,10 +1008,7 @@ def load_tensorrt_llm() -> bool: "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" ) return False - else: - platform = Platform.current_platform() - platform = str(platform).lower() with download_plugin_lib_path(platform) as plugin_lib_path: - return load_and_initialize_trtllm_plugin(plugin_lib_path) + return load_and_initialize_trtllm_plugin(plugin_lib_path, platform) return False diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index a71fd1edc4..abde5d8b76 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -6,6 +6,7 @@ from distributed_utils import set_environment_variables_pytest from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt._enums import Platform set_environment_variables_pytest() dist.init_process_group(backend="nccl", init_method="env://") @@ -15,7 +16,12 @@ from conversion.harness import DispatchTestCase +platform_str = str(Platform.current_platform()).lower() + +@unittest.skipIf( + "win" in platform_str, "Skipped on Windows: NCCL backend is not supported." 
+) class TestGatherNcclOpsConverter(DispatchTestCase): @parameterized.expand([8]) def test_nccl_ops(self, linear_layer_dim): From af2aa2ca3a81cec62566d63df39ddd3bee5484b9 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 1 Jul 2025 15:20:56 -0700 Subject: [PATCH 06/10] adding checks for windows and jetson devices --- .../conversion/custom_ops_converters.py | 4 +- py/torch_tensorrt/dynamo/utils.py | 41 +++++++++++++++---- tests/py/dynamo/distributed/test_nccl_ops.py | 14 +++++-- 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 3e67457e54..aecc99b1f1 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -15,11 +15,11 @@ tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_reduce_scatter_op, ) -from torch_tensorrt.dynamo.utils import load_tensorrt_llm +from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl _LOGGER: logging.Logger = logging.getLogger(__name__) -if load_tensorrt_llm(): +if load_tensorrt_llm_for_nccl(): @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) def fused_nccl_gather( diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index a6f2a90e53..700dce116e 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -841,6 +841,29 @@ def is_tegra_platform() -> bool: return False +def is_platform_supported_for_trtllm(platform: str) -> bool: + """ + Checks if the current platform supports TensorRT-LLM plugins for NCCL backend + Returns: + bool: True if the platform supports TensorRT-LLM plugins for NCCL backend, False otherwise. + Note: + TensorRT-LLM plugins for NCCL backend are not supported on: + - Windows platforms + - Jetson devices (aarch64 architecture) + """ + if "windows" in platform: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Windows" + ) + return False + if "aarch64" in platform: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices (aarch64)" + ) + return False + return True + + @contextmanager def download_plugin_lib_path(platform: str) -> Iterator[str]: """ @@ -891,6 +914,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]: if "linux" in platform: lib_filename = "libnvinfer_plugin_tensorrt_llm.so" else: + # This condition is never met though lib_filename = "libnvinfer_plugin_tensorrt_llm.dll" with tempfile.TemporaryDirectory() as tmpdir: @@ -923,7 +947,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]: yield plugin_lib_path -def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bool: +def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: """ Loads and initializes the TensorRT-LLM plugin from the given shared library path. @@ -933,9 +957,6 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo Returns: bool: True if successful, False otherwise. 
""" - if "windows" in platform: - logger.info("NCCL backend is not supported on Windows") - return False try: handle = ctypes.CDLL(plugin_lib_path) logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") @@ -981,7 +1002,7 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo return False -def load_tensorrt_llm() -> bool: +def load_tensorrt_llm_for_nccl() -> bool: """ Attempts to load the TensorRT-LLM plugin and initialize it. Either the env variable TRTLLM_PLUGINS_PATH can specify the path @@ -990,11 +1011,15 @@ def load_tensorrt_llm() -> bool: Returns: bool: True if the plugin was successfully loaded and initialized, False otherwise. """ - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + # Check platform compatibility first platform = Platform.current_platform() platform = str(platform).lower() + if not is_platform_supported_for_trtllm(platform): + return False + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + if plugin_lib_path: - return load_and_initialize_trtllm_plugin(plugin_lib_path, platform) + return load_and_initialize_trtllm_plugin(plugin_lib_path) else: # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( @@ -1010,5 +1035,5 @@ def load_tensorrt_llm() -> bool: return False with download_plugin_lib_path(platform) as plugin_lib_path: - return load_and_initialize_trtllm_plugin(plugin_lib_path, platform) + return load_and_initialize_trtllm_plugin(plugin_lib_path) return False diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index abde5d8b76..9ae6a03839 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -1,4 +1,5 @@ import os +import unittest import torch import torch.distributed as dist @@ -19,12 +20,13 @@ platform_str = str(Platform.current_platform()).lower() -@unittest.skipIf( - "win" in platform_str, "Skipped on Windows: NCCL backend is not supported." 
-) class TestGatherNcclOpsConverter(DispatchTestCase): + @unittest.skipIf( + "win" or "aarch64" in platform_str, + "Skipped on Windows and Jetson: NCCL backend is not supported.", + ) @parameterized.expand([8]) - def test_nccl_ops(self, linear_layer_dim): + def test_nccl_ops_gather(self, linear_layer_dim): class DistributedGatherModel(nn.Module): def __init__(self, input_dim): super().__init__() @@ -48,6 +50,10 @@ def forward(self, x): enable_passes=True, ) + @unittest.skipIf( + "win" or "aarch64" in platform_str, + "Skipped on Windows and Jetson: NCCL backend is not supported.", + ) @parameterized.expand([8]) def test_nccl_ops_scatter(self, linear_layer_dim): From 85730982c229e378ce629ce092ed7b1d1b601d30 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 3 Jul 2025 17:35:01 -0700 Subject: [PATCH 07/10] Keeping the extracted and deleting download file, restructuring test --- py/torch_tensorrt/dynamo/utils.py | 142 ++++++++++-------- .../dynamo/distributed/distributed_utils.py | 1 - tests/py/dynamo/distributed/test_nccl_ops.py | 103 +++++++------ tests/py/dynamo/distributed/test_nccl_ops.sh | 47 +----- 4 files changed, 135 insertions(+), 158 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 700dce116e..b30855b599 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -8,7 +8,6 @@ import tempfile import urllib.request import warnings -from contextlib import contextmanager from dataclasses import fields, replace from enum import Enum from pathlib import Path @@ -16,7 +15,6 @@ Any, Callable, Dict, - Iterator, List, Optional, Sequence, @@ -864,40 +862,52 @@ def is_platform_supported_for_trtllm(platform: str) -> bool: return True -@contextmanager -def download_plugin_lib_path(platform: str) -> Iterator[str]: - """ - Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform, - then yields the path to the extracted shared library (.so or .dll). +def _cache_root() -> Path: + username = getpass.getuser() + return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" - The wheel file is cached in a user-specific temporary directory to avoid repeated downloads. - Extraction happens in a temporary directory that is cleaned up after use. - Args: - platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel. +def _extracted_dir_trtllm(platform: str) -> Path: + return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}" - Yields: - str: The full path to the extracted TensorRT-LLM shared library file. - Raises: - ImportError: If the 'zipfile' module is not available. - RuntimeError: If the wheel file is missing, corrupted, or extraction fails. +def download_and_get_plugin_lib_path(platform: str) -> Optional[str]: """ - plugin_lib_path = None - username = getpass.getuser() - torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" - torchtrt_cache_dir.mkdir(parents=True, exist_ok=True) - file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl" - torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name - downloaded_file_path = torchtrt_cache_trtllm_whl - - if not torchtrt_cache_trtllm_whl.exists(): - # Downloading TRT-LLM lib + Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary. 
+ + Args: + platform (str): Platform identifier (e.g., 'linux_x86_64') + + Returns: + Optional[str]: Path to shared library or None if operation fails. + """ + wheel_filename = ( + f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-" + f"{_WHL_CPYTHON_VERSION}-{platform}.whl" + ) + wheel_path = _cache_root() / wheel_filename + extract_dir = _extracted_dir_trtllm(platform) + # else will never be met though + lib_filename = ( + "libnvinfer_plugin_tensorrt_llm.so" + if "linux" in platform + else "libnvinfer_plugin_tensorrt_llm.dll" + ) + # eg: /tmp/torch_tensorrt_/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so + plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename + + if plugin_lib_path.exists(): + return str(plugin_lib_path) + + wheel_path.parent.mkdir(parents=True, exist_ok=True) + extract_dir.mkdir(parents=True, exist_ok=True) + + if not wheel_path.exists(): base_url = "https://pypi.nvidia.com/tensorrt-llm/" - download_url = base_url + file_name + download_url = base_url + wheel_filename try: logger.debug(f"Downloading {download_url} ...") - urllib.request.urlretrieve(download_url, downloaded_file_path) + urllib.request.urlretrieve(download_url, wheel_path) logger.debug("Download succeeded and TRT-LLM wheel is now present") except urllib.error.HTTPError as e: logger.error( @@ -910,41 +920,45 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]: except OSError as e: logger.error(f"Local file write error: {e}") - # Proceeding with the unzip of the wheel file in tmpdir - if "linux" in platform: - lib_filename = "libnvinfer_plugin_tensorrt_llm.so" - else: - # This condition is never met though - lib_filename = "libnvinfer_plugin_tensorrt_llm.dll" + try: + import zipfile + except ImportError as e: + raise ImportError( + "zipfile module is required but not found. Please install zipfile" + ) + try: + with zipfile.ZipFile(wheel_path) as zip_ref: + zip_ref.extractall(extract_dir) + logger.debug(f"Extracted wheel to {extract_dir}") + except FileNotFoundError as e: + # This should capture the errors in the download failure above + logger.error(f"Wheel file not found at {wheel_path}: {e}") + raise RuntimeError( + f"Failed to find downloaded wheel file at {wheel_path}" + ) from e + except zipfile.BadZipFile as e: + logger.error(f"Invalid or corrupted wheel file: {e}") + raise RuntimeError( + "Downloaded wheel file is corrupted or not a valid zip archive" + ) from e + except Exception as e: + logger.error(f"Unexpected error while extracting wheel: {e}") + raise RuntimeError( + "Unexpected error during extraction of TensorRT-LLM wheel" + ) from e - with tempfile.TemporaryDirectory() as tmpdir: - try: - import zipfile - except ImportError: - raise ImportError( - "zipfile module is required but not found. 
Please install zipfile" - ) - try: - with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref: - zip_ref.extractall(tmpdir) # Extract to a folder named 'tensorrt_llm' - except FileNotFoundError as e: - # This should capture the errors in the download failure above - logger.error(f"Wheel file not found at {downloaded_file_path}: {e}") - raise RuntimeError( - f"Failed to find downloaded wheel file at {downloaded_file_path}" - ) from e - except zipfile.BadZipFile as e: - logger.error(f"Invalid or corrupted wheel file: {e}") - raise RuntimeError( - "Downloaded wheel file is corrupted or not a valid zip archive" - ) from e - except Exception as e: - logger.error(f"Unexpected error while extracting wheel: {e}") - raise RuntimeError( - "Unexpected error during extraction of TensorRT-LLM wheel" - ) from e - plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename) - yield plugin_lib_path + try: + wheel_path.unlink(missing_ok=True) + logger.debug(f"Deleted wheel file: {wheel_path}") + except Exception as e: + logger.warning(f"Could not delete wheel file {wheel_path}: {e}") + if not plugin_lib_path.exists(): + logger.error( + f"Plugin library not found at expected location: {plugin_lib_path}" + ) + return None + + return str(plugin_lib_path) def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: @@ -1034,6 +1048,6 @@ def load_tensorrt_llm_for_nccl() -> bool: ) return False - with download_plugin_lib_path(platform) as plugin_lib_path: - return load_and_initialize_trtllm_plugin(plugin_lib_path) + plugin_lib_path = download_and_get_plugin_lib_path(platform) + return load_and_initialize_trtllm_plugin(plugin_lib_path) # type: ignore[arg-type] return False diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py index e3062249fa..bc058aaaec 100644 --- a/tests/py/dynamo/distributed/distributed_utils.py +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -13,7 +13,6 @@ def set_environment_variables_pytest(): os.environ["RANK"] = str(0) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(29500) - os.environ["USE_TRTLLM_PLUGINS"] = "1" def initialize_logger(rank, logger_file_name): diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index 9ae6a03839..91bcc56f44 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -4,18 +4,42 @@ import torch import torch.distributed as dist import torch.nn as nn +from conversion.harness import DispatchTestCase from distributed_utils import set_environment_variables_pytest from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt._enums import Platform -set_environment_variables_pytest() -dist.init_process_group(backend="nccl", init_method="env://") -group = dist.new_group(ranks=[0]) -group_name = group.group_name -world_size = 1 -from conversion.harness import DispatchTestCase +class DistributedGatherModel(nn.Module): + def __init__(self, input_dim, world_size, group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x): + x = self.fc(x) + gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( + x, self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(gathered_tensor) + + +class DistributedReduceScatterModel(nn.Module): + def __init__(self, 
input_dim, world_size, group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x): + x = self.fc(x) + out = torch.ops._c10d_functional.reduce_scatter_tensor( + x, "sum", self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(out) + platform_str = str(Platform.current_platform()).lower() @@ -25,64 +49,49 @@ class TestGatherNcclOpsConverter(DispatchTestCase): "win" or "aarch64" in platform_str, "Skipped on Windows and Jetson: NCCL backend is not supported.", ) + @classmethod + def setUpClass(cls): + set_environment_variables_pytest() + print("USE_TRTLLM_PLUGINS =", os.environ.get("USE_TRTLLM_PLUGINS")) + cls.world_size = 1 + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", + init_method="env://", + world_size=cls.world_size, + rank=0, # or read from env + ) + cls.group = dist.new_group(ranks=[0]) + cls.group_name = cls.group.group_name + + @classmethod + def tearDownClass(cls): + if dist.is_initialized(): + dist.destroy_process_group() + @parameterized.expand([8]) def test_nccl_ops_gather(self, linear_layer_dim): - class DistributedGatherModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - - def forward(self, x): - x = self.fc(x) - gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( - x, world_size, group_name - ) - gathered_tensor = torch.ops._c10d_functional.wait_tensor( - gathered_tensor - ) - return gathered_tensor - inputs = [torch.randn(1, linear_layer_dim).to("cuda")] self.run_test( - DistributedGatherModel(linear_layer_dim).cuda(), + DistributedGatherModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, ) - @unittest.skipIf( - "win" or "aarch64" in platform_str, - "Skipped on Windows and Jetson: NCCL backend is not supported.", - ) @parameterized.expand([8]) def test_nccl_ops_scatter(self, linear_layer_dim): - - class DistributedReduceScatterModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - - def forward(self, x): - x = self.fc(x) - scatter_reduce_tensor = ( - torch.ops._c10d_functional.reduce_scatter_tensor( - x, "sum", world_size, group_name - ) - ) - scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor( - scatter_reduce_tensor - ) - return scatter_reduce_tensor - inputs = [torch.zeros(1, linear_layer_dim).to("cuda")] - self.run_test( - DistributedReduceScatterModel(linear_layer_dim).cuda(), + DistributedReduceScatterModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, ) - dist.destroy_process_group() if __name__ == "__main__": diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh index dd54700048..677d0cb9bc 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.sh +++ b/tests/py/dynamo/distributed/test_nccl_ops.sh @@ -70,51 +70,6 @@ ensure_pytest_installed(){ echo "Setting up the environment" -OS="$(uname -s)" -ARCH="$(uname -m)" - - -#getting the file name for TensorRT-LLM download -if [[ "$OS" == "Linux" && "$ARCH" == "x86_64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl" -elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl" -else: - echo "Unsupported platform: OS=$OS 
ARCH=$ARCH - exit 1 -fi - -# Download the selected file -URL="https://pypi.nvidia.com/tensorrt-llm/$FILE" -echo "Downloading $FILE from $URL..." - -#Installing wget -ensure_installed wget - -#Downloading the file -filename=$(basename "$URL") -if [ -f "$filename" ]; then - echo "File already exists: $filename" -else - wget "$URL" -fi -echo "Download complete: $FILE" - -UNZIP_DIR="tensorrt_llm_unzip" -if [[ ! -d "$UNZIP_DIR" ]]; then - echo "Creating directory: $UNZIP_DIR" - mkdir -p "$UNZIP_DIR" - echo "extracting $FILE to $UNZIP_DIR ..." - #Installing unzip - ensure_installed unzip - #unzip the TensorRT-LLM package - unzip -q "$FILE" -d "$UNZIP_DIR" - echo "Unzip complete" -fi - - -export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" -echo ${TRTLLM_PLUGINS_PATH} ensure_mpi_installed libmpich-dev ensure_mpi_installed libopenmpi-dev @@ -123,7 +78,7 @@ run_tests() { cd .. export PYTHONPATH=$(pwd) echo "Running pytest on distributed/test_nccl_ops.py..." - pytest distributed/test_nccl_ops.py + USE_TRTLLM_PLUGINS=1 pytest distributed/test_nccl_ops.py } run_mpi_tests(){ From 0c46f74816d97da38a8910fdef5dd63eef98020d Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 7 Jul 2025 11:46:36 -0700 Subject: [PATCH 08/10] modifying the error warning of missing libmpi libs --- py/torch_tensorrt/dynamo/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index b30855b599..701c920353 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -977,9 +977,7 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: except OSError as e_os_error: if "libmpi" in str(e_os_error): logger.warning( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " - f"The dependency libmpi.so is missing. 
" - f"Please install the packages libmpich-dev and libopenmpi-dev.", + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)", exc_info=e_os_error, ) else: From 16bf4d1e5cd0cdcf99476cf3943dea21464c1c1b Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 7 Jul 2025 11:54:51 -0700 Subject: [PATCH 09/10] removing the redundant initializations --- tests/py/dynamo/distributed/test_nccl_ops.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index 91bcc56f44..e8bca66efe 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -55,12 +55,7 @@ def setUpClass(cls): print("USE_TRTLLM_PLUGINS =", os.environ.get("USE_TRTLLM_PLUGINS")) cls.world_size = 1 if not dist.is_initialized(): - dist.init_process_group( - backend="nccl", - init_method="env://", - world_size=cls.world_size, - rank=0, # or read from env - ) + dist.init_process_group(backend="nccl") cls.group = dist.new_group(ranks=[0]) cls.group_name = cls.group.group_name From eb9da84fd89b91122b885962d434a9c42912c5df Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 17 Jul 2025 12:29:45 -0700 Subject: [PATCH 10/10] adding tests in CI --- .github/workflows/build-test-linux-x86_64.yml | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index 51f3730d02..dc67cba06a 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -337,6 +337,37 @@ jobs: python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . popd + tests-py-distributed: + name: Test dynamo distributed [Python] + needs: [filter-matrix, build] + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh + uses: ./.github/workflows/linux-test.yml + with: + job-name: tests-py-dynamo-distributed + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.filter-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + set -euo pipefail + export USE_HOST_DEPS=1 + export CI_BUILD=1 + pushd . + cd tests/py + cd dynamo + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_distributed_test_results.xml distributed/test_nccl_ops.py + popd + concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} cancel-in-progress: true