From f532f651bbd8aa7b0fbf93354f6e4c2b542b9b2c Mon Sep 17 00:00:00 2001
From: ssjia
Date: Fri, 22 Aug 2025 07:30:58 -0700
Subject: [PATCH 1/2] [ET-VK][ez] Fix partitioner logic of finding keepdim arg
 of reduce ops

Title says it all. For reduce ops, the signatures are not all alike, so
some extra legwork is needed to identify the specific arguments that
need to be checked.

Also includes a small update to partitioner logging to improve
debuggability.

Differential Revision: [D80741737](https://our.internmc.facebook.com/intern/diff/D80741737/)

[ghstack-poisoned]
---
 backends/vulkan/op_registry.py                    | 15 +++++++++------
 backends/vulkan/partitioner/vulkan_partitioner.py |  2 +-
 backends/vulkan/utils.py                          |  2 ++
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b7f8f3de955..a6cc59e26f0 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -397,14 +397,17 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         # If we can't get memory layout information, we'll assume the dims aren't packed
         pass
 
-    keepdim = node.args[2]
-    if isinstance(keepdim, bool) and not keepdim:
+    def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
+        for arg in node.args:
+            if isinstance(arg, bool):
+                return arg
+
+        # Assume false by default
         return False
 
-    if len(node.args) > 2:
-        keepdim = node.args[2]
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    keepdim = try_find_keepdim_arg(node)
+    if isinstance(keepdim, bool) and not keepdim:
+        return False
 
     return True
 
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 1b5ff0a44e4..04a1a500b64 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -204,7 +204,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
     def log_skip(self, node: torch.fx.Node, reason: str) -> None:
         if node.op == "call_function":
             logger.info(
-                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
+                f"[Vulkan Partitioner] Due to [{reason}], skipping {utils.node_io_str(node)}"
             )
 
     def is_node_supported(
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 1765f0b5e1c..bc03860ed3f 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -1059,6 +1059,8 @@ def get_node_val_str(node: torch.fx.Node) -> str:
         assert isinstance(node.meta["val"], (list, tuple))
         return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
     else:
+        if "val" not in node.meta:
+            return str(node)
         return str(node.meta["val"])
 
 

From 05a62d0eea908388dfbc2e1d889ae946bdd90d9a Mon Sep 17 00:00:00 2001
From: ssjia
Date: Fri, 22 Aug 2025 14:12:09 -0700
Subject: [PATCH 2/2] Update on "[ET-VK][ez] Fix partitioner logic of finding
 keepdim arg of reduce ops"

Title says it all. For reduce ops, the signatures are not all alike, so
some extra legwork is needed to identify the specific arguments that
need to be checked.

Also includes a small update to partitioner logging to improve
debuggability.

Differential Revision: [D80741737](https://our.internmc.facebook.com/intern/diff/D80741737/)

[ghstack-poisoned]
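
To illustrate why keepdim has to be located by scanning rather than read
from a fixed index, here is a minimal sketch of the new helper's behavior.
`FakeNode` and the example argument layouts are hypothetical stand-ins for
real `torch.fx.Node` instances; only the scan logic mirrors
`try_find_keepdim_arg` from the patch:

```python
# Minimal sketch of the keepdim scan in check_reduce_node. FakeNode is a
# hypothetical stand-in for torch.fx.Node, used so the sketch runs on its own.
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeNode:
    args: Tuple[Any, ...]


def try_find_keepdim_arg(node: FakeNode) -> bool:
    # Reduce op schemas place keepdim at different positions (and it may be
    # absent entirely), so scan for a bool positional arg instead of
    # hardcoding node.args[2].
    for arg in node.args:
        if isinstance(arg, bool):
            return arg
    # Assume False by default, as the patch does.
    return False


# e.g. mean.dim(self, dim, keepdim): keepdim appears at index 2
assert try_find_keepdim_arg(FakeNode(("x", [1], True))) is True
# a reduce call that omits keepdim falls back to the False default
assert try_find_keepdim_arg(FakeNode(("x", [1]))) is False
```

The scan rests on the assumption that keepdim is the only bool positional
argument in the reduce op schemas this check targets, which is why the first
bool found can be returned directly.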