|
import ctypes
import getpass
import logging
import os
import platform
import sys
import tempfile
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any, Optional

import tensorrt as trt
import torch

from _version import __tensorrt_llm_version__
| 15 | + |
| 16 | +logger = logging.getLogger(__name__) |
| 17 | + |
| 18 | +_WHL_CPYTHON_VERSION = "cp310" |
6 | 19 |
|
7 | 20 |
|
8 | 21 | def sanitized_torch_version() -> Any:
|
@@ -50,3 +63,253 @@ def is_tensorrt_version_supported(min_version: str) -> bool:
|
50 | 63 | except (ImportError, ValueError):
|
51 | 64 | # If tensorrt is not installed or version cannot be determined
|
52 | 65 | return False
|
| 66 | + |
| 67 | + |
def is_thor() -> bool:
    """Return True when the active CUDA device is an NVIDIA Thor (SM 11.0) GPU.

    Guarded for CPU-only builds: without an available CUDA device,
    ``torch.cuda.get_device_capability()`` raises, so report False instead.
    """
    if not torch.cuda.is_available():
        return False
    # Thor reports compute capability (11, 0).
    return torch.cuda.get_device_capability() == (11, 0)
| 72 | + |
| 73 | + |
def is_platform_supported_for_trtllm() -> bool:
    """
    Check whether the current platform supports TensorRT-LLM plugins for the NCCL backend.

    Returns:
        bool: True if supported, False otherwise.

    Unsupported:
        - Windows platforms
        - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release) and Thor devices
        - PyTorch builds without CUDA, or with a CUDA major version other than 12
    """
    system = platform.system().lower()
    machine = platform.machine().lower()
    release = platform.release().lower()

    if "windows" in system:
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Windows."
        )
        return False

    # Parenthesized for clarity: Tegra-based aarch64 boards OR Thor GPUs.
    if (machine == "aarch64" and "tegra" in release) or is_thor():
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) or Thor devices."
        )
        return False

    try:
        cuda_version = torch.version.cuda  # e.g. "12.4" or "13.0"; None on CPU-only builds
        if cuda_version is None:
            logger.error(
                "This pytorch build does not support CUDA, please reinstall pytorch with CUDA support"
            )
            return False

        major, minor = map(int, cuda_version.split("."))
        if major != 12:
            # Report the version actually found; the old message claimed
            # "CUDA 13" even when e.g. CUDA 11 was detected.
            logger.error(
                f"CUDA {major}.{minor} is not supported for TRT-LLM plugins. Please install pytorch with CUDA 12.x support"
            )
            return False

        return True
    except Exception as e:
        logger.warning(f"Failed to detect CUDA version: {e}")
        return False
| 124 | + |
| 125 | + |
| 126 | +def _cache_root() -> Path: |
| 127 | + username = getpass.getuser() |
| 128 | + return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" |
| 129 | + |
| 130 | + |
def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
    """Directory where the TRT-LLM wheel for this version/platform is unpacked."""
    version_tag = f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}"
    return _cache_root() / "trtllm" / version_tag
| 137 | + |
| 138 | + |
def download_and_get_plugin_lib_path() -> Optional[str]:
    """
    Return the path to the TensorRT-LLM plugin shared library, downloading
    and extracting the TRT-LLM wheel into a per-user cache if necessary.

    Returns:
        Optional[str]: Path to the shared library, or None if the download
        fails or the library is missing from the extracted wheel.

    Raises:
        RuntimeError: If a wheel file is present but corrupt or cannot be
            extracted.
    """
    import zipfile  # stdlib; local import keeps module import time minimal

    platform_system = platform.system().lower()
    platform_machine = platform.machine().lower()
    wheel_filename = (
        f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
        f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
    )
    wheel_path = _cache_root() / wheel_filename
    extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
    # Windows is rejected earlier by is_platform_supported_for_trtllm(), so the
    # .dll branch is effectively unreachable; kept for completeness.
    lib_filename = (
        "libnvinfer_plugin_tensorrt_llm.so"
        if "linux" in platform_system
        else "libnvinfer_plugin_tensorrt_llm.dll"
    )
    # e.g.: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
    plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename

    if plugin_lib_path.exists():
        return str(plugin_lib_path)

    wheel_path.parent.mkdir(parents=True, exist_ok=True)
    extract_dir.mkdir(parents=True, exist_ok=True)

    if not wheel_path.exists():
        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
        download_url = base_url + wheel_filename
        try:
            logger.debug(f"Downloading {download_url} ...")
            urllib.request.urlretrieve(download_url, wheel_path)
            logger.debug("Download succeeded and TRT-LLM wheel is now present")
        except urllib.error.HTTPError as e:
            logger.error(
                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
            )
            return None
        except urllib.error.URLError as e:
            logger.error(
                f"URL error when trying to download {download_url}: {e.reason}"
            )
            return None
        except OSError as e:
            logger.error(f"Local file write error: {e}")
            return None

    try:
        with zipfile.ZipFile(wheel_path) as zip_ref:
            zip_ref.extractall(extract_dir)
            logger.debug(f"Extracted wheel to {extract_dir}")
    except FileNotFoundError as e:
        logger.error(f"Wheel file not found at {wheel_path}: {e}")
        raise RuntimeError(
            f"Failed to find downloaded wheel file at {wheel_path}"
        ) from e
    except zipfile.BadZipFile as e:
        logger.error(f"Invalid or corrupted wheel file: {e}")
        raise RuntimeError(
            "Downloaded wheel file is corrupted or not a valid zip archive"
        ) from e
    except Exception as e:
        logger.error(f"Unexpected error while extracting wheel: {e}")
        raise RuntimeError(
            "Unexpected error during extraction of TensorRT-LLM wheel"
        ) from e

    try:
        # The extracted tree is all we need; reclaim the wheel's disk space.
        wheel_path.unlink(missing_ok=True)
        logger.debug(f"Deleted wheel file: {wheel_path}")
    except Exception as e:
        logger.warning(f"Could not delete wheel file {wheel_path}: {e}")

    if not plugin_lib_path.exists():
        logger.error(
            f"Plugin library not found at expected location: {plugin_lib_path}"
        )
        return None

    return str(plugin_lib_path)
| 229 | + |
| 230 | + |
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
    """
    Load and initialize the TensorRT-LLM plugin from the given shared library path.

    Args:
        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.

    Returns:
        bool: True if successful, False otherwise.
    """
    try:
        handle = ctypes.CDLL(plugin_lib_path)
        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
    except OSError as e_os_error:
        if "libmpi" in str(e_os_error):
            logger.warning(
                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)",
                exc_info=e_os_error,
            )
        else:
            logger.warning(
                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
                f"Ensure the path is correct and the library is compatible.",
                exc_info=e_os_error,
            )
        return False

    try:
        # Declare the C signature before calling: bool initTrtLlmPlugins(void*, const char*)
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_plugin_unavailable:
        # The symbol is missing: this library is not a TRT-LLM plugin build.
        logger.warning(
            "Unable to initialize the TensorRT-LLM plugin library",
            exc_info=e_plugin_unavailable,
        )
        return False

    try:
        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
            logger.info("TensorRT-LLM plugin successfully initialized")
            return True
        logger.warning("TensorRT-LLM plugin library failed in initialization")
        return False
    except Exception as e_initialization_error:
        logger.warning(
            "Exception occurred during TensorRT-LLM plugin library initialization",
            exc_info=e_initialization_error,
        )
        return False
| 282 | + |
| 283 | + |
def load_tensorrt_llm_for_nccl() -> bool:
    """
    Attempt to load the TensorRT-LLM plugin and initialize it.

    The plugin location is resolved in order:
      1. The TRTLLM_PLUGINS_PATH environment variable, if set.
      2. If USE_TRTLLM_PLUGINS is truthy ("1", "true", "yes", "on"), the
         TRT-LLM distribution is downloaded and the bundled library is used.

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    if not is_platform_supported_for_trtllm():
        return False

    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

    # Fallback used when TRTLLM_PLUGINS_PATH is not set by the user.
    use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if not use_trtllm_plugin:
        logger.warning(
            "Neither TRTLLM_PLUGINS_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
        )
        return False

    plugin_lib_path = download_and_get_plugin_lib_path()
    if plugin_lib_path is None:
        # Download/extraction failed. Do NOT pass None to ctypes.CDLL —
        # CDLL(None) loads the current process image, not the plugin.
        return False
    return load_and_initialize_trtllm_plugin(plugin_lib_path)
0 commit comments