
Conversation

@apbose (Collaborator) commented Sep 22, 2025

  1. TRT-LLM download utility: setting USE_TRTLLM_PLUGINS=1 triggers the download mechanism (a sketch follows this list).
  2. The wheel is downloaded to /tmp/torch_tensorrt_/ and unzipped.
  3. plugin_lib_path is set to /tmp/torch_tensorrt_/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so.
     Across runs the wheel is removed, while the .so file is retained.
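
For orientation, a minimal sketch of the flow described above. The helper name get_trtllm_plugin_path and the wheel URL are placeholders, not the PR's actual code; the cache paths mirror the description:

import os
import urllib.request
import zipfile
from typing import Optional

# Illustrative constants taken from the paths in the description above.
TRTLLM_VERSION = "0.17.0.post1"
CACHE_DIR = "/tmp/torch_tensorrt_"
PLUGIN_SO = "libnvinfer_plugin_tensorrt_llm.so"
WHEEL_URL = "<wheel URL>"  # placeholder; the real URL is not given here


def get_trtllm_plugin_path() -> Optional[str]:
    """Download and unpack the TRT-LLM wheel once; keep only the .so across runs."""
    if os.environ.get("USE_TRTLLM_PLUGINS", "0") != "1":
        return None  # the download mechanism is opt-in

    extract_dir = os.path.join(CACHE_DIR, "trtllm", f"{TRTLLM_VERSION}_linux_x86_64")
    plugin_path = os.path.join(extract_dir, "tensorrt_llm", "libs", PLUGIN_SO)
    if os.path.exists(plugin_path):
        return plugin_path  # .so retained from a previous run

    os.makedirs(extract_dir, exist_ok=True)
    wheel_path = os.path.join(CACHE_DIR, "tensorrt_llm.whl")
    urllib.request.urlretrieve(WHEEL_URL, wheel_path)
    with zipfile.ZipFile(wheel_path) as whl:  # a wheel is a zip archive
        whl.extractall(extract_dir)
    os.remove(wheel_path)  # wheel removed across runs; extracted .so is kept
    return plugin_path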

    tensorrt_fused_nccl_all_gather_op,
    tensorrt_fused_nccl_reduce_scatter_op,
)
if load_tensorrt_llm_for_nccl():
Collaborator:
I would like to use the enabled features system for this rather than a standalone function.

Unsupported:
- Windows platforms
- Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release)
Collaborator:

Thor is also not supported by TRT-LLM, right?

Collaborator (Author):

Yeah, Thor and SBSA should support NCCL, but about TRT-LLM I am not sure. Will include Thor in the list of unsupported platforms.
A follow-up question: what about SBSA? I see on the TRT-LLM page that Blackwell is supported, but that does not imply SBSA support, right? (It could be supported on B200, which is non-SBSA, vs. GB200, which is SBSA.)


if machine == "aarch64" and "tegra" in release:
    logger.info(
        "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices."
Collaborator:
Edit the error message here to include Thor.
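
A minimal sketch of the requested edit, assuming Thor devices, like other Jetson devices, report "tegra" in platform.release() (that assumption is not confirmed in this thread):

# Sketch only: extends the quoted check's message to mention Thor.
if machine == "aarch64" and "tegra" in release:
    logger.info(
        "TensorRT-LLM plugins for NCCL backend are not supported on "
        "Jetson/Orin/Xavier/Thor (Tegra) devices."
    )
    return False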

try:
    cuda_version = torch.version.cuda  # e.g., "12.4" or "13.0"
    if cuda_version is None:
        logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.")
Collaborator:
This is somewhat misleading, because the actual error is that the PyTorch install does not support CUDA.

Collaborator:
Also, if that is the case, would this be an error? What invokes this function? Should the user be able to continue running? Would they be under the assumption that TRT-LLM plugins would be available?

Collaborator (Author):
Yes, will change the error message.
In that case the CUDA runtime is not available, but I assume we would hit an error before even reaching this point. With respect to this function, we won't be able to verify whether CUDA is 12.x or 13.x. Should I remove this check altogether?
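
A minimal sketch of one possible rewording, reflecting the point above that torch.version.cuda being None means the PyTorch build lacks CUDA support rather than that no CUDA runtime was detected; the helper name and message are hypothetical:

import logging

import torch

logger = logging.getLogger(__name__)


def _torch_build_has_cuda() -> bool:
    # Hypothetical helper; name and wording are illustrative only.
    if torch.version.cuda is None:
        logger.warning(
            "This PyTorch build does not support CUDA; "
            "TRT-LLM plugins for the NCCL backend are unavailable."
        )
        return False
    return True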


major, minor = map(int, cuda_version.split("."))
if major != 12:
    logger.warning("CUDA 13 is not supported for TRT-LLM plugins.")
Collaborator:
"not currently supported". Same comment as above, though: this seems to me to be at least a log error, but the question is whether we should kill the process. If the program will not run as intended, we should; otherwise it is still an error, but we can continue.

Collaborator (Author):

Will change the error message to add "currently".
It's more that this function will then return False: load_tensorrt_llm_for_nccl() calls is_platform_supported_for_trtllm(), which will return False, and the converter will be reported as unsupported.
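
A minimal sketch of the control flow described in this reply; the two function names appear in the conversation, but the bodies here are illustrative:

def is_platform_supported_for_trtllm() -> bool:
    # Illustrative: returns False on unsupported platforms
    # (Windows, Tegra/Thor devices, non-12.x CUDA), per the threads above.
    ...


def load_tensorrt_llm_for_nccl() -> bool:
    # No hard failure: on an unsupported platform this simply returns False,
    # and the NCCL converters are then reported as unsupported.
    if not is_platform_supported_for_trtllm():
        return False
    # ... otherwise locate (or download) the plugin library and load it.
    return True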

    return False


def load_tensorrt_llm_for_nccl() -> bool:
Collaborator:

This function should be in the enabled features system, and it should register the feature for other parts of the library to query against.

Collaborator (Author):

Will make this change
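
A minimal sketch of what registering this in an enabled features system could look like, loosely modeled on the FeatureSet-style namedtuple torch_tensorrt uses for other optional features; all names here are assumptions, not the PR's final API:

from collections import namedtuple


def _platform_supports_trtllm() -> bool:
    # Stand-in for the is_platform_supported_for_trtllm() check above.
    return True


# Hypothetical "trtllm_for_nccl" feature flag, registered once at import time.
FeatureSet = namedtuple("FeatureSet", ["trtllm_for_nccl"])
ENABLED_FEATURES = FeatureSet(trtllm_for_nccl=_platform_supports_trtllm())

# Other parts of the library query the feature instead of calling a
# standalone function:
if ENABLED_FEATURES.trtllm_for_nccl:
    pass  # e.g., register the NCCL converter support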

@narendasan (Collaborator) commented Sep 26, 2025 via email

@apbose force-pushed the abose/trt_llm_installtion branch from 8dd657c to cee5c7a on October 1, 2025 02:05