
Commit 8dd657c

Addressing review comments: include in the enabled features set and improve error logging. Pending: check support on Thor and SBSA.
1 parent 9255c4a commit 8dd657c

3 files changed: +54 -37 lines changed


py/torch_tensorrt/_features.py

Lines changed: 16 additions & 0 deletions
@@ -9,6 +9,7 @@
     check_cross_compile_trt_win_lib,
     sanitized_torch_version,
 )
+from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl

 from packaging import version

@@ -23,6 +24,7 @@
         "qdp_plugin",
         "windows_cross_compile",
         "tensorrt_rtx",
+        "trtllm_for_nccl",
     ],
 )

@@ -48,6 +50,7 @@
 _FX_FE_AVAIL = False if _TENSORRT_RTX else True
 _REFIT_AVAIL = True
 _WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib()
+_TRTLLM_AVAIL = load_tensorrt_llm_for_nccl()

 if importlib.util.find_spec("tensorrt.plugin"):
     _QDP_PLUGIN_AVAIL = True
@@ -63,6 +66,7 @@
     _QDP_PLUGIN_AVAIL,
     _WINDOWS_CROSS_COMPILE,
     _TENSORRT_RTX,
+    _TRTLLM_AVAIL,
 )

 T = TypeVar("T")
@@ -158,6 +162,18 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
     return wrapper


+def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
+    def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+        if ENABLED_FEATURES.trtllm_for_nccl:
+            return f(*args, **kwargs)
+        else:
+            raise NotImplementedError(
+                "TensorRT-LLM plugins for NCCL backend could not be loaded"
+            )
+
+    return wrapper
+
+
 def for_all_methods(
     decorator: Callable[..., Any], exclude: Optional[List[str]] = None
 ) -> Callable[..., Any]:
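
For context (not part of this commit), a minimal sketch of how the new `trtllm_for_nccl` feature flag and the `needs_trtllm_for_nccl` decorator would typically be consumed downstream; `build_collective_layer` is a hypothetical name used only for illustration:

from typing import Any

from torch_tensorrt._features import ENABLED_FEATURES, needs_trtllm_for_nccl


@needs_trtllm_for_nccl
def build_collective_layer(*args: Any, **kwargs: Any) -> Any:
    # Hypothetical helper: the body runs only when the TRT-LLM NCCL plugins loaded;
    # otherwise the decorator raises NotImplementedError before the body executes.
    ...


# The flag can also be queried directly, e.g. to skip distributed code paths:
if not ENABLED_FEATURES.trtllm_for_nccl:
    print("TensorRT-LLM plugins for NCCL are unavailable")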

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 32 additions & 35 deletions
@@ -5,6 +5,7 @@

 import tensorrt as trt
 from torch.fx.node import Argument, Target
+from torch_tensorrt._features import needs_trtllm_for_nccl
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -15,45 +16,41 @@
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
-from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl

 _LOGGER: logging.Logger = logging.getLogger(__name__)

-if load_tensorrt_llm_for_nccl():

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
-    def fused_nccl_gather(
-        ctx: ConversionContext,
-        target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        name: str,
-    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.nccl_ops.nccl_gather(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name,
-            [args[0]],
-        )
+@needs_trtllm_for_nccl
+@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
+def fused_nccl_gather(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+    return impl.nccl_ops.nccl_gather(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        [args[0]],
+    )

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
-    def fused_nccl_reduce_scatter(
-        ctx: ConversionContext,
-        target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        name: str,
-    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.nccl_ops.nccl_reduce_scatter(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name,
-            [args[0]],
-        )

-else:
-    _LOGGER.debug(
-        "Did not load torch.distributed converters since TensorRT-LLM is not available"
+@needs_trtllm_for_nccl
+@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
+def fused_nccl_reduce_scatter(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+    return impl.nccl_ops.nccl_reduce_scatter(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        [args[0]],
     )
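
Design note: before this change the two NCCL converters were only registered when `load_tensorrt_llm_for_nccl()` succeeded at import time, with a debug log otherwise; after it they are registered unconditionally and the `needs_trtllm_for_nccl` guard raises at call time instead. A self-contained sketch of that pattern, using made-up names (`PLUGINS_AVAILABLE`, `guarded`, `my_converter`) and assuming nothing beyond standard Python:

from typing import Any, Callable

PLUGINS_AVAILABLE = False  # stand-in for ENABLED_FEATURES.trtllm_for_nccl


def guarded(f: Callable[..., Any]) -> Callable[..., Any]:
    # Call-time guard: the function stays importable and registerable,
    # but raises if the optional plugins were not loaded.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if PLUGINS_AVAILABLE:
            return f(*args, **kwargs)
        raise NotImplementedError("TensorRT-LLM plugins for NCCL could not be loaded")

    return wrapper


@guarded
def my_converter(x: int) -> int:
    return x + 1


try:
    my_converter(1)
except NotImplementedError as err:
    print(err)  # the failure surfaces only when the converter is actually invoked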

py/torch_tensorrt/dynamo/utils.py

Lines changed: 6 additions & 2 deletions
@@ -907,12 +907,16 @@ def is_platform_supported_for_trtllm() -> bool:
     try:
         cuda_version = torch.version.cuda  # e.g., "12.4" or "13.0"
         if cuda_version is None:
-            logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.")
+            logger.error(
+                "This pytorch build does not support CUDA, please reinstall pytorch with CUDA support"
+            )
             return False

         major, minor = map(int, cuda_version.split("."))
         if major != 12:
-            logger.warning("CUDA 13 is not supported for TRT-LLM plugins.")
+            logger.error(
+                "CUDA 13 is not supported for TRT-LLM plugins. Please install pytorch with CUDA 12.x support"
+            )
             return False

         return True
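
For reference, the version gate in this hunk reduces to the following standalone check (a simplified sketch; the function name `_cuda_12_available` is illustrative, and the real `is_platform_supported_for_trtllm()` may perform additional checks not shown in this hunk):

import torch


def _cuda_12_available() -> bool:
    # Simplified sketch of the CUDA gate above: TRT-LLM plugins for NCCL
    # require a CUDA-enabled PyTorch built against CUDA 12.x.
    cuda_version = torch.version.cuda  # e.g. "12.4", "13.0", or None on CPU-only builds
    if cuda_version is None:
        return False
    major, _minor = map(int, cuda_version.split("."))
    return major == 12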
