From 3ca29b0cb52c84e03265232ddf7e0491373d7bd5 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Fri, 23 Jan 2026 12:10:28 -0800 Subject: [PATCH 01/11] add template argument deduction module --- numbast/src/numbast/deduction.py | 360 +++++++++++++++++++++++++++++++ numbast/tests/test_deduction.py | 243 +++++++++++++++++++++ 2 files changed, 603 insertions(+) create mode 100644 numbast/src/numbast/deduction.py create mode 100644 numbast/tests/test_deduction.py diff --git a/numbast/src/numbast/deduction.py b/numbast/src/numbast/deduction.py new file mode 100644 index 00000000..92706056 --- /dev/null +++ b/numbast/src/numbast/deduction.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Iterable +import os +import re + +from ast_canopy import pylibastcanopy +from ast_canopy.decl import Function, FunctionTemplate, StructMethod + +from numba.cuda import types as nbtypes + +from numbast.intent import compute_intent_plan +from numbast.types import to_c_type_str, to_numba_type + + +_SPACE_RE = re.compile(r"\s+") +_DEBUG_ENV_VAR = "NUMBAST_TAD_DEBUG" + + +def _debug_enabled(debug: bool | None) -> bool: + if debug is not None: + return debug + value = os.environ.get(_DEBUG_ENV_VAR, "") + return value.strip().lower() in ("1", "true", "yes", "on") + + +def _debug_print(debug: bool | None, msg: str) -> None: + if _debug_enabled(debug): + print(f"[numbast.deduction] {msg}") + + +def _normalize_cxx_type_str(type_str: str) -> str: + """Normalize C++ type strings for comparison.""" + if not type_str: + return type_str + type_str = type_str.strip() + type_str = _SPACE_RE.sub(" ", type_str) + # Remove whitespace around pointer/reference markers. 
+ type_str = type_str.replace(" *", "*").replace("* ", "*") + type_str = type_str.replace(" &", "&").replace("& ", "&") + return type_str + + +def _apply_pass_ptr(cxx_param: str, pass_ptr: bool) -> str: + if not pass_ptr: + return cxx_param + if "*" in cxx_param: + return cxx_param + return _normalize_cxx_type_str(f"{cxx_param}*") + + +def _normalize_numba_arg_type(arg: nbtypes.Type) -> nbtypes.Type: + if isinstance(arg, nbtypes.Literal): + literal_type = getattr(arg, "literal_type", None) + if literal_type is not None: + return literal_type + return arg + + +def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None: + arg = _normalize_numba_arg_type(arg) + try: + return _normalize_cxx_type_str(to_c_type_str(arg)) + except Exception: + return None + + +def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool: + """Best-effort compatibility check for non-templated parameters.""" + nb_expected = to_numba_type(cxx_type) + if nb_expected is nbtypes.undefined: + return True + return nb_expected == _normalize_numba_arg_type(arg) + + +def _deduce_from_type_pattern( + cxx_type: str, arg_cxx: str, placeholders: Iterable[str] +) -> dict[str, str] | None: + placeholders_in_type = [p for p in placeholders if p in cxx_type] + if not placeholders_in_type: + return None + + pattern = re.escape(cxx_type) + order: list[str] = [] + for ph in placeholders_in_type: + pattern = pattern.replace(re.escape(ph), r"(.*?)") + order.append(ph) + + match = re.fullmatch(pattern, arg_cxx) + if not match: + return None + + deduced: dict[str, str] = {} + for ph, value in zip(order, match.groups()): + value = value.strip() + if not value: + return None + if ph in deduced and deduced[ph] != value: + return None + deduced[ph] = value + return deduced + + +def _replace_placeholders(type_str: str, replacements: dict[str, str]) -> str: + for key, value in replacements.items(): + type_str = type_str.replace(key, value) + return type_str + + +def _specialize_type( + ast_type: 
pylibastcanopy.Type, replacements: dict[str, str] +) -> pylibastcanopy.Type: + new_name = _replace_placeholders(ast_type.name, replacements) + new_unqualified = _replace_placeholders( + ast_type.unqualified_non_ref_type_name, replacements + ) + return pylibastcanopy.Type( + new_name, + new_unqualified, + ast_type.is_right_reference(), + ast_type.is_left_reference(), + ) + + +def _specialize_function( + func: Function, replacements: dict[str, str] +) -> Function: + new_return = _specialize_type(func.return_type, replacements) + new_params = [ + pylibastcanopy.ParamVar(p.name, _specialize_type(p.type_, replacements)) + for p in func.params + ] + + if isinstance(func, StructMethod): + return StructMethod( + func.name, + func.qual_name, + new_return, + new_params, + func.kind, + func.exec_space, + func.is_constexpr, + func.is_move_constructor, + func.mangled_name, + func.attributes, + func.parse_entry_point, + ) + + return Function( + func.name, + func.qual_name, + new_return, + new_params, + func.exec_space, + func.is_constexpr, + func.mangled_name, + func.attributes, + func.parse_entry_point, + ) + + +def _clone_function_template( + templ: FunctionTemplate, func: Function +) -> FunctionTemplate: + return FunctionTemplate( + templ.template_parameters, + templ.num_min_required_args, + func, + templ.qual_name, + templ.parse_entry_point, + ) + + +def _unresolved_placeholders( + func: Function, placeholder_names: Iterable[str] +) -> bool: + placeholders = tuple(placeholder_names) + if any( + p in func.return_type.unqualified_non_ref_type_name + for p in placeholders + ): + return True + for param in func.params: + if any( + p in param.type_.unqualified_non_ref_type_name for p in placeholders + ): + return True + return False + + +def deduce_templated_overloads( + *, + qualname: str, + overloads: list[FunctionTemplate], + args: tuple[nbtypes.Type, ...], + overrides: dict | None = None, + debug: bool | None = None, +) -> tuple[list[FunctionTemplate], list[Exception]]: + """ 
+ Perform template argument deduction for templated method overloads. + + Returns a list of FunctionTemplate objects with fully-specialized + Function/Method types, plus any arg_intent-related errors encountered + while computing visible arity. + + Enable debug output by passing debug=True or setting the + NUMBAST_TAD_DEBUG=1 environment variable. + """ + specialized: list[FunctionTemplate] = [] + intent_errors: list[Exception] = [] + + _debug_print( + debug, + f"begin: {qualname}, overloads={len(overloads)}, args={len(args)}, " + f"overrides={'yes' if overrides else 'no'}", + ) + + for idx, templ in enumerate(overloads): + _debug_print( + debug, + f"overload[{idx}] {templ.function.name}: " + f"params={[p.type_.unqualified_non_ref_type_name for p in templ.function.params]}", + ) + if overrides is None: + visible_param_indices = tuple(range(len(templ.function.params))) + pass_ptr_mask = tuple(False for _ in visible_param_indices) + else: + try: + plan = compute_intent_plan( + params=templ.function.params, + param_types=templ.function.param_types, + overrides=overrides, + allow_out_return=True, + ) + except Exception as exc: + intent_errors.append(exc) + _debug_print( + debug, + f" intent plan error: {exc}", + ) + continue + visible_param_indices = plan.visible_param_indices + pass_ptr_mask = plan.pass_ptr_mask + _debug_print( + debug, + " intent plan: " + f"visible={plan.visible_param_indices}, " + f"out_return={plan.out_return_indices}, " + f"pass_ptr={plan.pass_ptr_mask}", + ) + + if len(visible_param_indices) != len(args): + _debug_print( + debug, + " skip: visible arity mismatch " + f"(visible={len(visible_param_indices)} vs args={len(args)})", + ) + continue + + placeholder_names = [ + tp.type_.name + for tp in templ.template_parameters + if tp.kind == pylibastcanopy.template_param_kind.type_ + ] + _debug_print( + debug, + f" placeholders={placeholder_names or 'none'}", + ) + mapping: dict[str, str] = {} + failed = False + + for vis_pos, (arg, param_idx) in 
enumerate( + zip(args, visible_param_indices) + ): + param = templ.function.params[param_idx] + cxx_param = _normalize_cxx_type_str( + param.type_.unqualified_non_ref_type_name + ) + pass_ptr = bool(pass_ptr_mask[vis_pos]) + cxx_param = _apply_pass_ptr(cxx_param, pass_ptr) + arg_cxx = _numba_arg_to_cxx_type(arg) + if arg_cxx is None: + _debug_print( + debug, + f" arg[{param_idx}] {param.name}: " + f"numba={arg} could not map to C++ type", + ) + failed = True + break + + _debug_print( + debug, + f" arg[{param_idx}] {param.name}: " + f"param={cxx_param}, arg={arg_cxx}, pass_ptr={pass_ptr}", + ) + + deduced = _deduce_from_type_pattern( + cxx_param, arg_cxx, placeholder_names + ) + if deduced is None: + if not _param_type_matches_arg(cxx_param, arg): + _debug_print( + debug, + f" mismatch: param={cxx_param}, arg={arg}", + ) + failed = True + break + _debug_print( + debug, + " no deduction needed; param matches arg", + ) + continue + + _debug_print( + debug, + f" deduced={deduced}", + ) + + for key, value in deduced.items(): + prev = mapping.get(key) + if prev is not None and prev != value: + _debug_print( + debug, + f" conflict: {key}={prev} vs {value}", + ) + failed = True + break + mapping[key] = value + if failed: + break + + if failed: + _debug_print(debug, " skip: deduction failed") + continue + + _debug_print(debug, f" mapping={mapping}") + specialized_func = _specialize_function(templ.function, mapping) + if _unresolved_placeholders(specialized_func, placeholder_names): + _debug_print( + debug, " skip: unresolved placeholders after specialize" + ) + continue + + _debug_print( + debug, + " specialized: " + f"return={specialized_func.return_type.unqualified_non_ref_type_name}, " + f"params={[p.type_.unqualified_non_ref_type_name for p in specialized_func.params]}", + ) + specialized.append(_clone_function_template(templ, specialized_func)) + + _debug_print( + debug, + f"end: {qualname}, specialized={len(specialized)}, intent_errors={len(intent_errors)}", + ) + 
return specialized, intent_errors diff --git a/numbast/tests/test_deduction.py b/numbast/tests/test_deduction.py new file mode 100644 index 00000000..d7d32368 --- /dev/null +++ b/numbast/tests/test_deduction.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import textwrap + +import pytest + +from ast_canopy import parse_declarations_from_source +from numba.cuda import types as nbtypes + +from numbast.deduction import deduce_templated_overloads + + +_CXX_SOURCE = textwrap.dedent( + """\ + #pragma once + + template + __device__ T add(T a, T b) { return a + b; } + + template + __device__ T add(T a, T b, T c) { return a + b + c; } + + template + __device__ T add_int(int a, T b) { return a + b; } + + template + __device__ void store_ptr(T *out, T value) { *out = value; } + + template + __device__ void store_ref(T &out, T value) { out = value; } + + template + __device__ T return_only(); + + template + __device__ void bad_out(T value) { (void)value; } + + struct Box { + template + __device__ T mul(T a, T b) const { return a * b; } + + template + __device__ void write(T &out, T value) const { out = value; } + }; + """ +) + + +@pytest.fixture(scope="module") +def deduction_decls(tmp_path_factory): + tmp_dir = tmp_path_factory.mktemp("deduction") + header_path = tmp_dir / "deduction_sample.cuh" + header_path.write_text(_CXX_SOURCE, encoding="utf-8") + decls = parse_declarations_from_source( + str(header_path), + [str(header_path)], + "sm_80", + verbose=False, + ) + return decls + + +def _get_function_templates(decls, name: str): + templs = [ + templ + for templ in decls.function_templates + if templ.function.name == name + ] + if not templs: + raise AssertionError(f"No function templates found for {name!r}") + return templs + + +def _get_struct_method_templates(decls, struct_name: str, method_name: str): + for struct in 
decls.structs: + if struct.name == struct_name: + templs = [ + templ + for templ in struct.templated_methods + if templ.function.name == method_name + ] + if not templs: + raise AssertionError( + f"No templated methods found for {struct_name}::{method_name}" + ) + return templs + raise AssertionError(f"Struct {struct_name!r} not found") + + +def test_overload_arity_selection(deduction_decls): + """Select the overload that matches the visible argument arity.""" + overloads = _get_function_templates(deduction_decls, "add") + specialized, intent_errors = deduce_templated_overloads( + qualname="add", + overloads=overloads, + args=(nbtypes.int32, nbtypes.int32), + overrides=None, + ) + + assert intent_errors == [] + assert len(specialized) == 1 + func = specialized[0].function + assert len(func.params) == 2 + assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [ + "int", + "int", + ] + assert func.return_type.unqualified_non_ref_type_name == "int" + + +def test_conflicting_deduction_skips_overload(deduction_decls): + """Skip overloads when template placeholders deduce conflicting types.""" + overloads = _get_function_templates(deduction_decls, "add") + specialized, intent_errors = deduce_templated_overloads( + qualname="add", + overloads=overloads, + args=(nbtypes.int32, nbtypes.float32), + overrides=None, + ) + + assert intent_errors == [] + assert specialized == [] + + +def test_non_templated_param_requires_match(deduction_decls): + """Require exact matches for non-templated parameters.""" + overloads = _get_function_templates(deduction_decls, "add_int") + specialized, intent_errors = deduce_templated_overloads( + qualname="add_int", + overloads=overloads, + args=(nbtypes.int32, nbtypes.float32), + overrides=None, + ) + + assert intent_errors == [] + assert len(specialized) == 1 + func = specialized[0].function + assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [ + "int", + "float", + ] + assert 
func.return_type.unqualified_non_ref_type_name == "float" + + specialized, intent_errors = deduce_templated_overloads( + qualname="add_int", + overloads=overloads, + args=(nbtypes.float32, nbtypes.float32), + overrides=None, + ) + + assert intent_errors == [] + assert specialized == [] + + +def test_return_only_placeholder_skipped(deduction_decls): + """Skip overloads with unresolved placeholders only in return type.""" + overloads = _get_function_templates(deduction_decls, "return_only") + specialized, intent_errors = deduce_templated_overloads( + qualname="return_only", + overloads=overloads, + args=(), + overrides=None, + ) + + assert intent_errors == [] + assert specialized == [] + + +def test_pass_ptr_override_deduces_from_pointer(deduction_decls): + """Allow out_ptr overrides to pass pointers for reference params.""" + overloads = _get_function_templates(deduction_decls, "store_ref") + specialized, intent_errors = deduce_templated_overloads( + qualname="store_ref", + overloads=overloads, + args=(nbtypes.CPointer(nbtypes.int32), nbtypes.int32), + overrides={"out": "out_ptr"}, + ) + + assert intent_errors == [] + assert len(specialized) == 1 + func = specialized[0].function + assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [ + "int", + "int", + ] + assert func.params[0].type_.is_left_reference() + assert not func.params[1].type_.is_left_reference() + + +def test_invalid_override_surfaces_error(deduction_decls): + """Return intent errors when overrides are incompatible with param types.""" + overloads = _get_function_templates(deduction_decls, "bad_out") + specialized, intent_errors = deduce_templated_overloads( + qualname="bad_out", + overloads=overloads, + args=(nbtypes.int32,), + overrides={"value": "out_ptr"}, + ) + + assert specialized == [] + assert intent_errors + err = intent_errors[0] + assert isinstance(err, ValueError) + assert "reference parameters" in str(err) + + +def test_struct_method_specialization(deduction_decls): + 
"""Specialize templated struct methods into concrete types.""" + overloads = _get_struct_method_templates(deduction_decls, "Box", "mul") + specialized, intent_errors = deduce_templated_overloads( + qualname="Box.mul", + overloads=overloads, + args=(nbtypes.float32, nbtypes.float32), + overrides=None, + ) + + assert intent_errors == [] + assert len(specialized) == 1 + func = specialized[0].function + assert func.name == "mul" + assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [ + "float", + "float", + ] + assert func.return_type.unqualified_non_ref_type_name == "float" + + +def test_unmappable_numba_arg_skips_overload(deduction_decls): + """Skip overloads when Numba args cannot map to C++ types.""" + overloads = _get_function_templates(deduction_decls, "add") + unmappable = nbtypes.float32[:] + specialized, intent_errors = deduce_templated_overloads( + qualname="add", + overloads=overloads, + args=(unmappable, unmappable), + overrides=None, + ) + + assert intent_errors == [] + assert specialized == [] From 3393c88c724c88eabb65b24f3fd6045100c6ee60 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Fri, 23 Jan 2026 12:58:47 -0800 Subject: [PATCH 02/11] add test for out_return intent kind --- numbast/src/numbast/types.py | 2 ++ numbast/tests/test_deduction.py | 19 +++++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py index 53992597..025a1ab8 100644 --- a/numbast/src/numbast/types.py +++ b/numbast/src/numbast/types.py @@ -142,6 +142,8 @@ def to_numba_arg_type(ast_type) -> nbtypes.Type: def to_c_type_str(nbty: nbtypes.Type) -> str: + if isinstance(nbty, nbtypes.CPointer): + return f"{to_c_type_str(nbty.dtype)}*" if nbty not in NUMBA_TO_CTYPE_MAPS: raise ValueError( f"Unknown numba type attempted to converted into ctype: {nbty}" diff --git a/numbast/tests/test_deduction.py b/numbast/tests/test_deduction.py index d7d32368..7b3d284f 100644 --- 
a/numbast/tests/test_deduction.py +++ b/numbast/tests/test_deduction.py @@ -169,14 +169,25 @@ def test_return_only_placeholder_skipped(deduction_decls): assert specialized == [] -def test_pass_ptr_override_deduces_from_pointer(deduction_decls): - """Allow out_ptr overrides to pass pointers for reference params.""" +@pytest.mark.parametrize( + ("overrides", "args"), + [ + ({"out": "out_ptr"}, (nbtypes.CPointer(nbtypes.int32), nbtypes.int32)), + ( + {"out": "inout_ptr"}, + (nbtypes.CPointer(nbtypes.int32), nbtypes.int32), + ), + ({"out": "out_return"}, (nbtypes.int32,)), + ], +) +def test_override_intents_adjust_deduction(deduction_decls, overrides, args): + """Allow overrides (including out_return) to adjust deduction behavior.""" overloads = _get_function_templates(deduction_decls, "store_ref") specialized, intent_errors = deduce_templated_overloads( qualname="store_ref", overloads=overloads, - args=(nbtypes.CPointer(nbtypes.int32), nbtypes.int32), - overrides={"out": "out_ptr"}, + args=args, + overrides=overrides, ) assert intent_errors == [] From 833c46a3c9463bca70558ef4c137485821764cfe Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Fri, 23 Jan 2026 13:54:12 -0800 Subject: [PATCH 03/11] install g++ for pre-commit check --- ci/check_style.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ci/check_style.sh b/ci/check_style.sh index b6ab5ed0..3e0c604f 100755 --- a/ci/check_style.sh +++ b/ci/check_style.sh @@ -4,6 +4,10 @@ set -euo pipefail +# Install pre-commit requirements +sudo apt-get update +sudo apt-get install -y g++ + pip install pre-commit # Run pre-commit checks From 2e530b96dc5143e6d896127bbe2590aa9067fec0 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Fri, 23 Jan 2026 13:58:59 -0800 Subject: [PATCH 04/11] install g++ without sudo --- ci/check_style.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/check_style.sh b/ci/check_style.sh index 3e0c604f..d57071b7 100755 --- a/ci/check_style.sh +++ b/ci/check_style.sh @@ 
-5,8 +5,8 @@ set -euo pipefail # Install pre-commit requirements -sudo apt-get update -sudo apt-get install -y g++ +apt-get update +apt-get install -y g++ pip install pre-commit From 93c12354442746b9e43dce789d955227c88c9d2c Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 26 Jan 2026 13:49:44 -0800 Subject: [PATCH 05/11] collapse something like int[42]* into int* when deducing type and selecting overloads --- numbast/src/numbast/deduction.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/numbast/src/numbast/deduction.py b/numbast/src/numbast/deduction.py index 92706056..d4ffe434 100644 --- a/numbast/src/numbast/deduction.py +++ b/numbast/src/numbast/deduction.py @@ -18,6 +18,7 @@ _SPACE_RE = re.compile(r"\s+") +_ARRAY_SUFFIX_RE = re.compile(r"^(?P.+?)\s*\[(?P[^\]]*)\]\s*$") _DEBUG_ENV_VAR = "NUMBAST_TAD_DEBUG" @@ -48,6 +49,11 @@ def _normalize_cxx_type_str(type_str: str) -> str: def _apply_pass_ptr(cxx_param: str, pass_ptr: bool) -> str: if not pass_ptr: return cxx_param + array_match = _ARRAY_SUFFIX_RE.match(cxx_param) + if array_match: + # Arrays decay to element pointers when passed by value. 
+ base = array_match.group("base").strip() + return _normalize_cxx_type_str(f"{base}*") if "*" in cxx_param: return cxx_param return _normalize_cxx_type_str(f"{cxx_param}*") From d912c9ccfa0d1071ce2dfb5fbb44e3c641e8a6ae Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 26 Jan 2026 14:16:46 -0800 Subject: [PATCH 06/11] add test for int[8]* use case --- numbast/tests/test_deduction.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/numbast/tests/test_deduction.py b/numbast/tests/test_deduction.py index 7b3d284f..57428dec 100644 --- a/numbast/tests/test_deduction.py +++ b/numbast/tests/test_deduction.py @@ -45,6 +45,11 @@ template __device__ void write(T &out, T value) const { out = value; } }; + + template + __device__ void fill_array(T (&arr)[4], T value) { + for (int i = 0; i < 4; ++i) arr[i] = value; + } """ ) @@ -252,3 +257,22 @@ def test_unmappable_numba_arg_skips_overload(deduction_decls): assert intent_errors == [] assert specialized == [] + + +def test_array_param_decays_to_pointer_with_pass_ptr(deduction_decls): + """Array types like T[N] decay to T* when pass_ptr is applied.""" + overloads = _get_function_templates(deduction_decls, "fill_array") + specialized, intent_errors = deduce_templated_overloads( + qualname="fill_array", + overloads=overloads, + args=(nbtypes.CPointer(nbtypes.int32), nbtypes.int32), + overrides={"arr": "out_ptr"}, + ) + + assert intent_errors == [] + assert len(specialized) == 1 + func = specialized[0].function + # After specialization, params should be int[4] (reference) and int + assert func.params[0].type_.unqualified_non_ref_type_name == "int[4]" + assert func.params[0].type_.is_left_reference() + assert func.params[1].type_.unqualified_non_ref_type_name == "int" From fa20340577ec3f223edbaeb384ce7475ecee1367 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 26 Jan 2026 14:17:04 -0800 Subject: [PATCH 07/11] narrow expected exception to ValueError --- numbast/src/numbast/deduction.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/numbast/src/numbast/deduction.py b/numbast/src/numbast/deduction.py index d4ffe434..58af41ac 100644 --- a/numbast/src/numbast/deduction.py +++ b/numbast/src/numbast/deduction.py @@ -71,7 +71,7 @@ def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None: arg = _normalize_numba_arg_type(arg) try: return _normalize_cxx_type_str(to_c_type_str(arg)) - except Exception: + except ValueError: return None From 1347d320604f8ea1f7277b3678a3f69a26700ccf Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 26 Jan 2026 14:20:19 -0800 Subject: [PATCH 08/11] return nbtypes.undefined for unknown types --- numbast/src/numbast/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py index 025a1ab8..6eadce02 100644 --- a/numbast/src/numbast/types.py +++ b/numbast/src/numbast/types.py @@ -126,7 +126,7 @@ def to_numba_type(ty: str): base_ty, size = is_array_type.groups() return nbtypes.UniTuple(to_numba_type(base_ty), int(size)) - return CTYPE_MAPS[ty] + return CTYPE_MAPS.get(ty, nbtypes.undefined) def to_numba_arg_type(ast_type) -> nbtypes.Type: From 01b08c1e08342a341e707fdc0a430eed8ea9e1b5 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 26 Jan 2026 14:52:08 -0800 Subject: [PATCH 09/11] address bugs where placeholder appearing in multiple locations --- numbast/src/numbast/deduction.py | 50 +++++++++++++++++++++++++++++--- numbast/tests/test_deduction.py | 15 +++++++++- 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/numbast/src/numbast/deduction.py b/numbast/src/numbast/deduction.py index 58af41ac..c9dd1277 100644 --- a/numbast/src/numbast/deduction.py +++ b/numbast/src/numbast/deduction.py @@ -86,15 +86,57 @@ def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool: def _deduce_from_type_pattern( cxx_type: str, arg_cxx: str, placeholders: Iterable[str] ) -> dict[str, str] | None: + """Deduce template placeholder 
values from a C++ type pattern. + + This helper treats ``cxx_type`` as a pattern that may contain one or more + placeholder names (e.g., ``T``) supplied in ``placeholders``. It scans + ``cxx_type`` left-to-right, replacing each placeholder occurrence with a + non-greedy capture group, and records the placeholder name for every + occurrence. The resulting regex is matched against ``arg_cxx``. If the + match fails, no deduction is possible and ``None`` is returned. + + For each captured group, the value is stripped of whitespace. Empty captures + are rejected. When a placeholder appears multiple times, all occurrences + must resolve to the same concrete type; conflicting values cause this + function to return ``None``. On success, a mapping from placeholder name to + deduced type string is returned. + + Example: + # Pattern from a templated parameter. + cxx_type = "Pair<T, U>" + arg_cxx = "Pair<float, int>" + placeholders = ["T", "U"] + + # Deduce placeholder values from the concrete argument type. + deduced = _deduce_from_type_pattern(cxx_type, arg_cxx, placeholders) + # deduced == {"T": "float", "U": "int"} + + # The same placeholder can appear multiple times; all captures must agree. 
+ repeat = _deduce_from_type_pattern("Pair<T, T>", "Pair<int, int>", ["T"]) + # repeat == {"T": "int"} + """ placeholders_in_type = [p for p in placeholders if p in cxx_type] if not placeholders_in_type: return None - pattern = re.escape(cxx_type) + unique_placeholders = list(dict.fromkeys(placeholders_in_type)) + placeholder_patterns = sorted(unique_placeholders, key=len, reverse=True) + placeholder_regex = re.compile( + "|".join(re.escape(ph) for ph in placeholder_patterns) + ) + order: list[str] = [] - for ph in placeholders_in_type: - pattern = pattern.replace(re.escape(ph), r"(.*?)") - order.append(ph) + pattern_parts: list[str] = [] + last_index = 0 + for ph_match in placeholder_regex.finditer(cxx_type): + start, end = ph_match.span() + if start > last_index: + pattern_parts.append(re.escape(cxx_type[last_index:start])) + pattern_parts.append(r"(.*?)") + order.append(ph_match.group(0)) + last_index = end + pattern_parts.append(re.escape(cxx_type[last_index:])) + pattern = "".join(pattern_parts) match = re.fullmatch(pattern, arg_cxx) if not match: diff --git a/numbast/tests/test_deduction.py b/numbast/tests/test_deduction.py index 57428dec..db69cb49 100644 --- a/numbast/tests/test_deduction.py +++ b/numbast/tests/test_deduction.py @@ -10,7 +10,10 @@ from ast_canopy import parse_declarations_from_source from numba.cuda import types as nbtypes -from numbast.deduction import deduce_templated_overloads +from numbast.deduction import ( + _deduce_from_type_pattern, + deduce_templated_overloads, +) _CXX_SOURCE = textwrap.dedent( @@ -130,6 +133,16 @@ def test_conflicting_deduction_skips_overload(deduction_decls): assert specialized == [] +def test_repeated_placeholder_conflict_in_type_pattern(): + """Reject types when repeated placeholders deduce conflicting values.""" + deduced = _deduce_from_type_pattern( + "pair", + "pair", + ["T"], + ) + assert deduced is None + + def test_non_templated_param_requires_match(deduction_decls): + """Require exact matches for non-templated 
parameters.""" overloads = _get_function_templates(deduction_decls, "add_int") From 2c61b49e9e827c9fcaec41b9f0ae3bdabe5d71c3 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 26 Jan 2026 15:00:40 -0800 Subject: [PATCH 10/11] narrow expected errors --- numbast/src/numbast/deduction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numbast/src/numbast/deduction.py b/numbast/src/numbast/deduction.py index c9dd1277..3d231b22 100644 --- a/numbast/src/numbast/deduction.py +++ b/numbast/src/numbast/deduction.py @@ -284,7 +284,7 @@ def deduce_templated_overloads( overrides=overrides, allow_out_return=True, ) - except Exception as exc: + except (ValueError, TypeError) as exc: intent_errors.append(exc) _debug_print( debug, From 30290ac98d30fd4d20c18897ac9357fb5664ac87 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 26 Jan 2026 15:15:18 -0800 Subject: [PATCH 11/11] add FIXME --- numbast/src/numbast/types.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py index 6eadce02..e0f9c876 100644 --- a/numbast/src/numbast/types.py +++ b/numbast/src/numbast/types.py @@ -126,6 +126,8 @@ def to_numba_type(ty: str): base_ty, size = is_array_type.groups() return nbtypes.UniTuple(to_numba_type(base_ty), int(size)) + # FIXME: Currently returning an undefined type. But in a future PR, we will + # return an opaque type instead. return CTYPE_MAPS.get(ty, nbtypes.undefined)