Skip to content

Commit 6dfb740

Browse files
committed
fixing circular imports
1 parent 6bbd852 commit 6dfb740

File tree

3 files changed

+265
-261
lines changed

3 files changed

+265
-261
lines changed

py/torch_tensorrt/_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import tensorrt
88
from torch_tensorrt._utils import (
99
check_cross_compile_trt_win_lib,
10+
load_tensorrt_llm_for_nccl,
1011
sanitized_torch_version,
1112
)
12-
from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl
1313

1414
from packaging import version
1515

py/torch_tensorrt/_utils.py

Lines changed: 264 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
1+
import ctypes
import getpass
import logging
import os
import platform
import sys
import tempfile
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any, Optional

import tensorrt as trt
import torch

from _version import __tensorrt_llm_version__

logger = logging.getLogger(__name__)

_WHL_CPYTHON_VERSION = "cp310"
619

720

821
def sanitized_torch_version() -> Any:
@@ -50,3 +63,253 @@ def is_tensorrt_version_supported(min_version: str) -> bool:
5063
except (ImportError, ValueError):
5164
# If tensorrt is not installed or version cannot be determined
5265
return False
66+
67+
68+
def is_thor() -> bool:
    """Return True when the active CUDA device reports Thor compute capability (SM 11.0)."""
    # NOTE(review): raises if no CUDA device is available — callers appear to
    # assume a GPU is present; confirm before using on CPU-only builds.
    thor_capabilities = [(11, 0)]
    return torch.cuda.get_device_capability() in thor_capabilities
72+
73+
74+
def is_platform_supported_for_trtllm() -> bool:
    """
    Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.

    Returns:
        bool: True if supported, False otherwise.

    Unsupported:
        - Windows platforms
        - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release) and Thor devices
        - CUDA versions other than 12.x
    """
    system = platform.system().lower()
    machine = platform.machine().lower()
    release = platform.release().lower()

    if "windows" in system:
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Windows."
        )
        return False

    # Parentheses make the intended precedence explicit:
    # (Tegra-class aarch64 device) OR (Thor GPU).
    if (machine == "aarch64" and "tegra" in release) or is_thor():
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) or Thor devices."
        )
        return False

    try:
        cuda_version = torch.version.cuda  # e.g., "12.4" or "13.0"
        if cuda_version is None:
            logger.error(
                "This pytorch build does not support CUDA, please reinstall pytorch with CUDA support"
            )
            return False

        major, _minor = map(int, cuda_version.split("."))
        if major != 12:
            # Only CUDA 12.x builds ship compatible TRT-LLM plugin binaries.
            # (Message previously hard-coded "CUDA 13" even for other majors.)
            logger.error(
                f"CUDA {major} is not supported for TRT-LLM plugins. Please install pytorch with CUDA 12.x support"
            )
            return False

        return True

    except Exception as e:
        logger.warning(f"Failed to detect CUDA version: {e}")
        return False
124+
125+
126+
def _cache_root() -> Path:
127+
username = getpass.getuser()
128+
return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
129+
130+
131+
def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
    """Directory where the TRT-LLM wheel for this version/platform is extracted."""
    versioned_subdir = f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}"
    return _cache_root() / "trtllm" / versioned_subdir
137+
138+
139+
def download_and_get_plugin_lib_path() -> Optional[str]:
    """
    Return the path to the TensorRT-LLM plugin shared library, downloading
    and extracting the TRT-LLM wheel into a per-user cache if necessary.

    Returns:
        Optional[str]: Path to the shared library, or None if the library
        is missing after download/extraction.

    Raises:
        RuntimeError: If the wheel is absent (download failed) or cannot be
        extracted.
    """
    import zipfile  # stdlib; local import mirrors the original lazy load

    platform_system = platform.system().lower()
    platform_machine = platform.machine().lower()
    wheel_filename = (
        f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
        f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
    )
    wheel_path = _cache_root() / wheel_filename
    extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
    # Windows is already rejected by is_platform_supported_for_trtllm(), so
    # the .dll branch should never be taken; kept for completeness.
    lib_filename = (
        "libnvinfer_plugin_tensorrt_llm.so"
        if "linux" in platform_system
        else "libnvinfer_plugin_tensorrt_llm.dll"
    )
    # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
    plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename

    # Fast path: a previous run already extracted the library.
    if plugin_lib_path.exists():
        return str(plugin_lib_path)

    wheel_path.parent.mkdir(parents=True, exist_ok=True)
    extract_dir.mkdir(parents=True, exist_ok=True)

    if not wheel_path.exists():
        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
        download_url = base_url + wheel_filename
        try:
            logger.debug(f"Downloading {download_url} ...")
            urllib.request.urlretrieve(download_url, wheel_path)
            logger.debug("Download succeeded and TRT-LLM wheel is now present")
        except urllib.error.HTTPError as e:
            logger.error(
                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
            )
        except urllib.error.URLError as e:
            logger.error(
                f"URL error when trying to download {download_url}: {e.reason}"
            )
        except OSError as e:
            logger.error(f"Local file write error: {e}")

    try:
        with zipfile.ZipFile(wheel_path) as zip_ref:
            zip_ref.extractall(extract_dir)
            logger.debug(f"Extracted wheel to {extract_dir}")
    except FileNotFoundError as e:
        # Reached when the download above failed and left no wheel behind.
        logger.error(f"Wheel file not found at {wheel_path}: {e}")
        raise RuntimeError(
            f"Failed to find downloaded wheel file at {wheel_path}"
        ) from e
    except zipfile.BadZipFile as e:
        logger.error(f"Invalid or corrupted wheel file: {e}")
        raise RuntimeError(
            "Downloaded wheel file is corrupted or not a valid zip archive"
        ) from e
    except Exception as e:
        logger.error(f"Unexpected error while extracting wheel: {e}")
        raise RuntimeError(
            "Unexpected error during extraction of TensorRT-LLM wheel"
        ) from e

    # Best effort cleanup; the extracted tree is what we actually need.
    try:
        wheel_path.unlink(missing_ok=True)
        logger.debug(f"Deleted wheel file: {wheel_path}")
    except Exception as e:
        logger.warning(f"Could not delete wheel file {wheel_path}: {e}")
    if not plugin_lib_path.exists():
        logger.error(
            f"Plugin library not found at expected location: {plugin_lib_path}"
        )
        return None

    return str(plugin_lib_path)
229+
230+
231+
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
    """
    Loads and initializes the TensorRT-LLM plugin from the given shared library path.

    Args:
        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.

    Returns:
        bool: True if successful, False otherwise.
    """
    try:
        handle = ctypes.CDLL(plugin_lib_path)
        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
    except OSError as e_os_error:
        if "libmpi" in str(e_os_error):
            logger.warning(
                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)",
                exc_info=e_os_error,
            )
        else:
            logger.warning(
                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
                f"Ensure the path is correct and the library is compatible.",
                exc_info=e_os_error,
            )
        return False

    try:
        # initTrtLlmPlugins(void* logger, const char* libNamespace) -> bool
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_plugin_unavailable:
        # Symbol missing: the library is not a TRT-LLM plugin build.
        logger.warning(
            "Unable to initialize the TensorRT-LLM plugin library",
            exc_info=e_plugin_unavailable,
        )
        return False

    try:
        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
            logger.info("TensorRT-LLM plugin successfully initialized")
            return True
        else:
            logger.warning("TensorRT-LLM plugin library failed in initialization")
            return False
    except Exception as e_initialization_error:
        logger.warning(
            "Exception occurred during TensorRT-LLM plugin library initialization",
            exc_info=e_initialization_error,
        )
        return False
282+
283+
284+
def load_tensorrt_llm_for_nccl() -> bool:
    """
    Attempts to load the TensorRT-LLM plugin and initialize it.
    Either the env variable TRTLLM_PLUGINS_PATH can specify the path
    Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    if not is_platform_supported_for_trtllm():
        return False
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")

    if plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

    # This opt-in can be used if TRTLLM_PLUGINS_PATH is not set by the user.
    use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if not use_trtllm_plugin:
        # Message previously named the wrong variable (TRTLLM_PLUGIN_PATH);
        # corrected to the env var actually read above.
        logger.warning(
            "Neither TRTLLM_PLUGINS_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
        )
        return False

    plugin_lib_path = download_and_get_plugin_lib_path()
    if plugin_lib_path is None:
        # Download or extraction failed; nothing to load.
        return False
    return load_and_initialize_trtllm_plugin(plugin_lib_path)

0 commit comments

Comments
 (0)