From 9cd09fd3c02be4697eedc8bd8e9a1cb1d1315f2b Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Fri, 25 Jul 2025 18:23:44 +0100
Subject: [PATCH 1/2] =?UTF-8?q?Revert=20"[SWDEV-543214]=20Set=20max=5Fnumw?=
 =?UTF-8?q?arps=20based=20on=20warp=5Fsize=20instead=20of=20hardcod?=
 =?UTF-8?q?=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit be95f40220f322bf1a00ff26c6546b23b9828ca0.

(cherry picked from commit b32459d06becc98265a99286c9bada4e6b32dc8f)
---
 torch/_inductor/runtime/coordinate_descent_tuner.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index c8d36da3ebc9c..b41ca81ebdfc6 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -2,7 +2,6 @@
 import copy
 import itertools
 import logging
-from functools import cached_property
 from typing import Callable, Optional, TYPE_CHECKING
 
 from .hints import TRITON_MAX_BLOCK
@@ -61,15 +60,10 @@ def get_config_max(self, prefix: str) -> int:
         size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
         return min(max_block, size_hint) if size_hint is not None else max_block
 
-    @cached_property
     def get_warpsmax(self):
-        # CUDA/ROCm has a maximum of 1024 threads per block
-        from torch.cuda import current_device, get_device_properties, is_available
-
-        warp_size = (
-            get_device_properties(current_device()).warp_size if is_available() else 32
-        )
-        return 1024 // warp_size
+        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
+        # number of warps.
+        return 1024 // 32
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing

From 74321c23acc2f4782febd085096ffea32f57a539 Mon Sep 17 00:00:00 2001
From: Jack Taylor
Date: Mon, 28 Jul 2025 11:13:36 -0500
Subject: [PATCH 2/2] Set max_numwarps based on warp_size instead of hardcoding

---
 torch/_inductor/runtime/coordinate_descent_tuner.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index b41ca81ebdfc6..b7dd523af87be 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -3,6 +3,7 @@
 import itertools
 import logging
 from typing import Callable, Optional, TYPE_CHECKING
+from functools import lru_cache
 
 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable
@@ -60,10 +61,15 @@ def get_config_max(self, prefix: str) -> int:
         size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
         return min(max_block, size_hint) if size_hint is not None else max_block
 
+    @lru_cache(maxsize=1)
     def get_warpsmax(self):
-        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
-        # number of warps.
-        return 1024 // 32
+        # CUDA/ROCm has a maximum of 1024 threads per block
+        from torch.cuda import current_device, get_device_properties, is_available
+
+        warp_size = (
+            get_device_properties(current_device()).warp_size if is_available() else 32
+        )
+        return 1024 // warp_size
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
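
For reference, a minimal standalone sketch (not part of either patch) of the
warp-size-aware cap that the re-applied get_warpsmax() computes; the helper
name max_num_warps and its threads_per_block parameter are illustrative only:

    # Derive the maximum number of warps per block from the device's warp
    # (wavefront) size instead of hardcoding 32, mirroring the patched method.
    from torch.cuda import current_device, get_device_properties, is_available

    def max_num_warps(threads_per_block: int = 1024) -> int:
        # Fall back to a warp size of 32 when no GPU is available.
        warp_size = (
            get_device_properties(current_device()).warp_size
            if is_available()
            else 32
        )
        return threads_per_block // warp_size

    # 1024 // 32 == 32 on warp-size-32 (CUDA) devices;
    # 1024 // 64 == 16 on warp-size-64 (ROCm wave64) devices.
    print(max_num_warps())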