diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index a81bfcf039..e061aebbd1 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -344,6 +344,40 @@ jobs: python -m pytest -m "not critical" -ra -n auto --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_torch_compile_dyn_models_tests_results.xml --ir torch_compile models/test_dyn_models.py popd + L1-dynamo-distributed-tests: + name: Test dynamo distributed [Python] + needs: [filter-matrix, build] + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh + uses: ./.github/workflows/linux-test.yml + with: + job-name: tests-py-dynamo-distributed + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.filter-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + set -euo pipefail + export USE_HOST_DEPS=1 + export CI_BUILD=1 + export USE_TRTLLM_PLUGINS=1 + dnf install -y mpich mpich-devel openmpi openmpi-devel + pushd . + cd tests/py + cd dynamo + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_distributed_test_results.xml distributed/test_nccl_ops.py + popd + + L2-dynamo-compile-tests: name: L2 dynamo compile tests needs: [filter-matrix, build, L1-dynamo-compile-tests, L1-dynamo-core-tests, L1-torch-compile-tests, L1-torchscript-tests] diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index 1159951385..8f3df6e509 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,3 +1,4 @@ __cuda_version__: "12.8" __tensorrt_version__: "10.13.3" __tensorrt_rtx_version__: "1.0.0" +__tensorrt_llm_version__: "0.17.0.post1" diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index f7e4e91626..03cf4256ec 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -7,6 +7,7 @@ import tensorrt from torch_tensorrt._utils import ( check_cross_compile_trt_win_lib, + load_tensorrt_llm_for_nccl, sanitized_torch_version, ) @@ -23,6 +24,7 @@ "qdp_plugin", "windows_cross_compile", "tensorrt_rtx", + "trtllm_for_nccl", ], ) @@ -48,6 +50,7 @@ _FX_FE_AVAIL = False if _TENSORRT_RTX else True _REFIT_AVAIL = True _WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib() +_TRTLLM_AVAIL = load_tensorrt_llm_for_nccl() if importlib.util.find_spec("tensorrt.plugin"): _QDP_PLUGIN_AVAIL = True @@ -63,6 +66,7 @@ _QDP_PLUGIN_AVAIL, _WINDOWS_CROSS_COMPILE, _TENSORRT_RTX, + _TRTLLM_AVAIL, ) T = TypeVar("T") @@ -158,6 +162,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return wrapper +def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + if ENABLED_FEATURES.trtllm_for_nccl: + return f(*args, **kwargs) + else: + + def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + raise NotImplementedError( + "Refit feature is currently not available in Python 3.13 or higher" + ) + + return not_implemented(*args, **kwargs) + + return wrapper + + def for_all_methods( decorator: Callable[..., Any], exclude: Optional[List[str]] = None ) -> Callable[..., Any]: diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index b981fb325a..e56c6f0742 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -1,9 +1,22 @@ +import ctypes +import getpass +import logging +import os +import platform import sys -from typing import Any +import tempfile +import urllib.request +from pathlib import Path +from typing import Any, Optional import tensorrt as trt import torch +logger = logging.getLogger(__name__) + +_WHL_CPYTHON_VERSION = "cp310" +_TENSORRT_LLM_VERSION_ = "0.17.0.post1" + def sanitized_torch_version() -> Any: return ( @@ -50,3 +63,253 @@ def is_tensorrt_version_supported(min_version: str) -> bool: except (ImportError, ValueError): # If tensorrt is not installed or version cannot be determined return False + + +def is_thor() -> bool: + if torch.cuda.get_device_capability() in [(11, 0)]: + return True + return False + + +def is_platform_supported_for_trtllm() -> bool: + """ + Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend. + + Returns: + bool: True if supported, False otherwise. + + Unsupported: + - Windows platforms + - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release) + - CUDA 13 not supported + """ + system = platform.system().lower() + machine = platform.machine().lower() + release = platform.release().lower() + + if "windows" in system: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Windows." + ) + return False + + if machine == "aarch64" and "tegra" in release or is_thor(): + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) or Thor devices." + ) + return False + + try: + cuda_version = torch.version.cuda # e.g., "12.4" or "13.0" + if cuda_version is None: + logger.error( + "This pytorch build does not support CUDA, please reinstall pytorch with CUDA support" + ) + return False + + major, minor = map(int, cuda_version.split(".")) + if major != 12: + logger.error( + "CUDA 13 is not supported for TRT-LLM plugins. Please install pytorch with CUDA 12.x support" + ) + return False + + return True + + except Exception as e: + logger.warning(f"Failed to detect CUDA version: {e}") + return False + + return True + + +def _cache_root() -> Path: + username = getpass.getuser() + return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" + + +def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: + return ( + _cache_root() + / "trtllm" + / f"{_TENSORRT_LLM_VERSION_}_{platform_system}_{platform_machine}" + ) + + +def download_and_get_plugin_lib_path() -> Optional[str]: + """ + Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary. + + Args: + platform (str): Platform identifier (e.g., 'linux_x86_64') + + Returns: + Optional[str]: Path to shared library or None if operation fails. + """ + platform_system = platform.system().lower() + platform_machine = platform.machine().lower() + wheel_filename = ( + f"tensorrt_llm-{_TENSORRT_LLM_VERSION_}-{_WHL_CPYTHON_VERSION}-" + f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl" + ) + wheel_path = _cache_root() / wheel_filename + extract_dir = _extracted_dir_trtllm(platform_system, platform_machine) + # else will never be met though + lib_filename = ( + "libnvinfer_plugin_tensorrt_llm.so" + if "linux" in platform_system + else "libnvinfer_plugin_tensorrt_llm.dll" + ) + # eg: /tmp/torch_tensorrt_/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so + plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename + + if plugin_lib_path.exists(): + return str(plugin_lib_path) + + wheel_path.parent.mkdir(parents=True, exist_ok=True) + extract_dir.mkdir(parents=True, exist_ok=True) + + if not wheel_path.exists(): + base_url = "https://pypi.nvidia.com/tensorrt-llm/" + download_url = base_url + wheel_filename + try: + logger.debug(f"Downloading {download_url} ...") + urllib.request.urlretrieve(download_url, wheel_path) + logger.debug("Download succeeded and TRT-LLM wheel is now present") + except urllib.error.HTTPError as e: + logger.error( + f"HTTP error {e.code} when trying to download {download_url}: {e.reason}" + ) + except urllib.error.URLError as e: + logger.error( + f"URL error when trying to download {download_url}: {e.reason}" + ) + except OSError as e: + logger.error(f"Local file write error: {e}") + + try: + import zipfile + except ImportError as e: + raise ImportError( + "zipfile module is required but not found. Please install zipfile" + ) + try: + with zipfile.ZipFile(wheel_path) as zip_ref: + zip_ref.extractall(extract_dir) + logger.debug(f"Extracted wheel to {extract_dir}") + except FileNotFoundError as e: + # This should capture the errors in the download failure above + logger.error(f"Wheel file not found at {wheel_path}: {e}") + raise RuntimeError( + f"Failed to find downloaded wheel file at {wheel_path}" + ) from e + except zipfile.BadZipFile as e: + logger.error(f"Invalid or corrupted wheel file: {e}") + raise RuntimeError( + "Downloaded wheel file is corrupted or not a valid zip archive" + ) from e + except Exception as e: + logger.error(f"Unexpected error while extracting wheel: {e}") + raise RuntimeError( + "Unexpected error during extraction of TensorRT-LLM wheel" + ) from e + + try: + wheel_path.unlink(missing_ok=True) + logger.debug(f"Deleted wheel file: {wheel_path}") + except Exception as e: + logger.warning(f"Could not delete wheel file {wheel_path}: {e}") + if not plugin_lib_path.exists(): + logger.error( + f"Plugin library not found at expected location: {plugin_lib_path}" + ) + return None + + return str(plugin_lib_path) + + +def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: + """ + Loads and initializes the TensorRT-LLM plugin from the given shared library path. + + Args: + plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. + + Returns: + bool: True if successful, False otherwise. + """ + try: + handle = ctypes.CDLL(plugin_lib_path) + logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") + except OSError as e_os_error: + if "libmpi" in str(e_os_error): + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)", + exc_info=e_os_error, + ) + else: + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " + f"Ensure the path is correct and the library is compatible.", + exc_info=e_os_error, + ) + return False + + try: + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + logger.warning( + "Unable to initialize the TensorRT-LLM plugin library", + exc_info=e_plugin_unavailable, + ) + return False + + try: + if handle.initTrtLlmPlugins(None, b"tensorrt_llm"): + logger.info("TensorRT-LLM plugin successfully initialized") + return True + else: + logger.warning("TensorRT-LLM plugin library failed in initialization") + return False + except Exception as e_initialization_error: + logger.warning( + "Exception occurred during TensorRT-LLM plugin library initialization", + exc_info=e_initialization_error, + ) + return False + return False + + +def load_tensorrt_llm_for_nccl() -> bool: + """ + Attempts to load the TensorRT-LLM plugin and initialize it. + Either the env variable TRTLLM_PLUGINS_PATH can specify the path + Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it + + Returns: + bool: True if the plugin was successfully loaded and initialized, False otherwise. + """ + if not is_platform_supported_for_trtllm(): + return False + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + + if plugin_lib_path: + return load_and_initialize_trtllm_plugin(plugin_lib_path) + else: + # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user + use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + if not use_trtllm_plugin: + logger.warning( + "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" + ) + return False + + plugin_lib_path = download_and_get_plugin_lib_path() + return load_and_initialize_trtllm_plugin(plugin_lib_path) # type: ignore[arg-type] + return False diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 0dc4654db0..a83a622fdf 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -103,6 +103,7 @@ def cross_compile_for_windows( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -176,6 +177,7 @@ def cross_compile_for_windows( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -330,6 +332,7 @@ def cross_compile_for_windows( "enable_weight_streaming": enable_weight_streaming, "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, + "use_distributed_mode_trace": use_distributed_mode_trace, } # disable the following settings is not supported for cross compilation for windows feature @@ -430,6 +433,7 @@ def compile( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -506,6 +510,7 @@ def compile( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -674,6 +679,7 @@ def compile( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) @@ -1045,6 +1051,7 @@ def convert_exported_program_to_serialized_trt_engine( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -1118,6 +1125,7 @@ def convert_exported_program_to_serialized_trt_engine( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model. **kwargs: Any, Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs @@ -1286,6 +1294,7 @@ def convert_exported_program_to_serialized_trt_engine( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 3828f97f99..094de488ec 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,8 +1,6 @@ import collections -import ctypes import functools import logging -import os from typing import ( Any, Callable, @@ -1124,69 +1122,6 @@ def args_bounds_check( return args[i] if len(args) > i and args[i] is not None else replacement -def load_tensorrt_llm() -> bool: - """ - Attempts to load the TensorRT-LLM plugin and initialize it. - - Returns: - bool: True if the plugin was successfully loaded and initialized, False otherwise. - """ - try: - import tensorrt_llm as trt_llm # noqa: F401 - - _LOGGER.info("TensorRT-LLM successfully imported") - return True - except (ImportError, AssertionError) as e_import_error: - # Check for environment variable for the plugin library path - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") - if not plugin_lib_path: - _LOGGER.warning( - "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops", - ) - return False - - _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}") - try: - # Load the shared library - handle = ctypes.CDLL(plugin_lib_path) - _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") - except OSError as e_os_error: - _LOGGER.error( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" - f"Ensure the path is correct and the library is compatible", - exc_info=e_os_error, - ) - return False - - try: - # Configure plugin initialization arguments - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] - handle.initTrtLlmPlugins.restype = ctypes.c_bool - except AttributeError as e_plugin_unavailable: - _LOGGER.warning( - "Unable to initialize the TensorRT-LLM plugin library", - exc_info=e_plugin_unavailable, - ) - return False - - try: - # Initialize the plugin - TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" - if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): - _LOGGER.info("TensorRT-LLM plugin successfully initialized") - return True - else: - _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") - return False - except Exception as e_initialization_error: - _LOGGER.warning( - "Exception occurred during TensorRT-LLM plugin library initialization", - exc_info=e_initialization_error, - ) - return False - return False - - def promote_trt_tensors_to_same_dtype( ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str ) -> tuple[TRTTensor, TRTTensor]: diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 1442c2b17b..db14e3528b 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -5,55 +5,52 @@ import tensorrt as trt from torch.fx.node import Argument, Target +from torch_tensorrt._features import needs_trtllm_for_nccl from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, ) -from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm +from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_gather_op, + tensorrt_fused_nccl_reduce_scatter_op, +) _LOGGER: logging.Logger = logging.getLogger(__name__) -if load_tensorrt_llm(): - from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( - tensorrt_fused_nccl_all_gather_op, - tensorrt_fused_nccl_reduce_scatter_op, - ) - @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) - def fused_nccl_gather( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, - ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_gather( - ctx, - target, - SourceIR.ATEN, - name, - [args[0]], - ) +@needs_trtllm_for_nccl +@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) +def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_gather( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) - @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) - def fused_nccl_reduce_scatter( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, - ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_reduce_scatter( - ctx, - target, - SourceIR.ATEN, - name, - [args[0]], - ) -else: - _LOGGER.debug( - "Did not load torch.distributed converters since TensorRT-LLM is not available" +@needs_trtllm_for_nccl +@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) +def fused_nccl_reduce_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_reduce_scatter( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], ) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 564250e5ae..e7927bef84 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -5,7 +5,16 @@ import warnings from dataclasses import fields, replace from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import sympy @@ -90,11 +99,9 @@ def unified_dtype_converter( ) -> Union[np.dtype, torch.dtype, TRTDataType]: """ Convert TensorRT, Numpy, or Torch data types to any other of those data types. - Args: dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type. to (Frameworks): The framework to convert the data type to. - Returns: The equivalent data type in the requested framework. """ @@ -852,9 +859,3 @@ def is_tegra_platform() -> bool: if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]: return True return False - - -def is_thor() -> bool: - if torch.cuda.get_device_capability() in [(11, 0)]: - return True - return False diff --git a/setup.py b/setup.py index 878e4de4ca..d487530626 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ __cuda_version__: str = "0.0" __tensorrt_version__: str = "0.0" __tensorrt_rtx_version__: str = "0.0" +__tensorrt_llm_version__: str = "0.0" LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$") # CI_PIPELINE_ID is the environment variable set by DLFW ci build @@ -69,6 +70,7 @@ def load_dep_info(): global __cuda_version__ global __tensorrt_version__ global __tensorrt_rtx_version__ + global __tensorrt_llm_version__ with open("dev_dep_versions.yml", "r") as stream: versions = yaml.safe_load(stream) if (gpu_arch_version := os.environ.get("CU_VERSION")) is not None: @@ -79,6 +81,7 @@ def load_dep_info(): __cuda_version__ = versions["__cuda_version__"] __tensorrt_version__ = versions["__tensorrt_version__"] __tensorrt_rtx_version__ = versions["__tensorrt_rtx_version__"] + __tensorrt_llm_version__ = versions["__tensorrt_llm_version__"] load_dep_info() @@ -245,6 +248,7 @@ def gen_version_file(): f.write('__cuda_version__ = "' + __cuda_version__ + '"\n') f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n') f.write('__tensorrt_rtx_version__ = "' + __tensorrt_rtx_version__ + '"\n') + f.write('__tensorrt_llm_version__ = "' + __tensorrt_llm_version__ + '"\n') def copy_libtorchtrt(multilinux=False, rt_only=False): diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py index e3062249fa..bc058aaaec 100644 --- a/tests/py/dynamo/distributed/distributed_utils.py +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -13,7 +13,6 @@ def set_environment_variables_pytest(): os.environ["RANK"] = str(0) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(29500) - os.environ["USE_TRTLLM_PLUGINS"] = "1" def initialize_logger(rank, logger_file_name): diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index 89c94300b7..eafe16d455 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -1,42 +1,72 @@ import os +import unittest import torch import torch.distributed as dist import torch.nn as nn +from conversion.harness import DispatchTestCase from distributed_utils import set_environment_variables_pytest from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt._utils import is_platform_supported_for_trtllm -set_environment_variables_pytest() -dist.init_process_group(backend="nccl", init_method="env://") -group = dist.new_group(ranks=[0]) -group_name = group.group_name -world_size = 1 -from conversion.harness import DispatchTestCase +class DistributedGatherModel(nn.Module): + def __init__(self, input_dim, world_size, group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + def forward(self, x): + x = self.fc(x) + gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( + x, self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(gathered_tensor) -class TestGatherNcclOpsConverter(DispatchTestCase): - @parameterized.expand([8]) - def test_nccl_ops(self, linear_layer_dim): - class DistributedGatherModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - def forward(self, x): - x = self.fc(x) - gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( - x, world_size, group_name - ) - gathered_tensor = torch.ops._c10d_functional.wait_tensor( - gathered_tensor - ) - return gathered_tensor +class DistributedReduceScatterModel(nn.Module): + def __init__(self, input_dim, world_size, group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x): + x = self.fc(x) + out = torch.ops._c10d_functional.reduce_scatter_tensor( + x, "sum", self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(out) + +class TestNcclOpsConverter(DispatchTestCase): + @unittest.skipIf( + not is_platform_supported_for_trtllm(), + "Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.", + ) + @classmethod + def setUpClass(cls): + set_environment_variables_pytest() + cls.world_size = 1 + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + cls.group = dist.new_group(ranks=[0]) + cls.group_name = cls.group.group_name + + @classmethod + def tearDownClass(cls): + if dist.is_initialized(): + dist.destroy_process_group() + + @parameterized.expand([8]) + def test_nccl_ops_gather(self, linear_layer_dim): inputs = [torch.randn(1, linear_layer_dim).to("cuda")] self.run_test( - DistributedGatherModel(linear_layer_dim).cuda(), + DistributedGatherModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, @@ -44,28 +74,11 @@ def forward(self, x): @parameterized.expand([8]) def test_nccl_ops_scatter(self, linear_layer_dim): - - class DistributedReduceScatterModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - - def forward(self, x): - x = self.fc(x) - scatter_reduce_tensor = ( - torch.ops._c10d_functional.reduce_scatter_tensor( - x, "sum", world_size, group_name - ) - ) - scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor( - scatter_reduce_tensor - ) - return scatter_reduce_tensor - inputs = [torch.zeros(1, linear_layer_dim).to("cuda")] - self.run_test( - DistributedReduceScatterModel(linear_layer_dim).cuda(), + DistributedReduceScatterModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh index dd54700048..677d0cb9bc 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.sh +++ b/tests/py/dynamo/distributed/test_nccl_ops.sh @@ -70,51 +70,6 @@ ensure_pytest_installed(){ echo "Setting up the environment" -OS="$(uname -s)" -ARCH="$(uname -m)" - - -#getting the file name for TensorRT-LLM download -if [[ "$OS" == "Linux" && "$ARCH" == "x86_64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl" -elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl" -else: - echo "Unsupported platform: OS=$OS ARCH=$ARCH - exit 1 -fi - -# Download the selected file -URL="https://pypi.nvidia.com/tensorrt-llm/$FILE" -echo "Downloading $FILE from $URL..." - -#Installing wget -ensure_installed wget - -#Downloading the file -filename=$(basename "$URL") -if [ -f "$filename" ]; then - echo "File already exists: $filename" -else - wget "$URL" -fi -echo "Download complete: $FILE" - -UNZIP_DIR="tensorrt_llm_unzip" -if [[ ! -d "$UNZIP_DIR" ]]; then - echo "Creating directory: $UNZIP_DIR" - mkdir -p "$UNZIP_DIR" - echo "extracting $FILE to $UNZIP_DIR ..." - #Installing unzip - ensure_installed unzip - #unzip the TensorRT-LLM package - unzip -q "$FILE" -d "$UNZIP_DIR" - echo "Unzip complete" -fi - - -export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" -echo ${TRTLLM_PLUGINS_PATH} ensure_mpi_installed libmpich-dev ensure_mpi_installed libopenmpi-dev @@ -123,7 +78,7 @@ run_tests() { cd .. export PYTHONPATH=$(pwd) echo "Running pytest on distributed/test_nccl_ops.py..." - pytest distributed/test_nccl_ops.py + USE_TRTLLM_PLUGINS=1 pytest distributed/test_nccl_ops.py } run_mpi_tests(){