diff --git a/ci/check_style.sh b/ci/check_style.sh
index b6ab5ed0..d57071b7 100755
--- a/ci/check_style.sh
+++ b/ci/check_style.sh
@@ -4,6 +4,10 @@
set -euo pipefail
+# Install pre-commit requirements
+apt-get update
+apt-get install -y g++
+
pip install pre-commit
# Run pre-commit checks
diff --git a/numbast/src/numbast/deduction.py b/numbast/src/numbast/deduction.py
new file mode 100644
index 00000000..3d231b22
--- /dev/null
+++ b/numbast/src/numbast/deduction.py
@@ -0,0 +1,408 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Iterable
+import os
+import re
+
+from ast_canopy import pylibastcanopy
+from ast_canopy.decl import Function, FunctionTemplate, StructMethod
+
+from numba.cuda import types as nbtypes
+
+from numbast.intent import compute_intent_plan
+from numbast.types import to_c_type_str, to_numba_type
+
+
+_SPACE_RE = re.compile(r"\s+")
+_ARRAY_SUFFIX_RE = re.compile(r"^(?P.+?)\s*\[(?P[^\]]*)\]\s*$")
+_DEBUG_ENV_VAR = "NUMBAST_TAD_DEBUG"
+
+
+def _debug_enabled(debug: bool | None) -> bool:
+ if debug is not None:
+ return debug
+ value = os.environ.get(_DEBUG_ENV_VAR, "")
+ return value.strip().lower() in ("1", "true", "yes", "on")
+
+
+def _debug_print(debug: bool | None, msg: str) -> None:
+ if _debug_enabled(debug):
+ print(f"[numbast.deduction] {msg}")
+
+
+def _normalize_cxx_type_str(type_str: str) -> str:
+ """Normalize C++ type strings for comparison."""
+ if not type_str:
+ return type_str
+ type_str = type_str.strip()
+ type_str = _SPACE_RE.sub(" ", type_str)
+ # Remove whitespace around pointer/reference markers.
+ type_str = type_str.replace(" *", "*").replace("* ", "*")
+ type_str = type_str.replace(" &", "&").replace("& ", "&")
+ return type_str
+
+
+def _apply_pass_ptr(cxx_param: str, pass_ptr: bool) -> str:
+ if not pass_ptr:
+ return cxx_param
+ array_match = _ARRAY_SUFFIX_RE.match(cxx_param)
+ if array_match:
+ # Arrays decay to element pointers when passed by value.
+ base = array_match.group("base").strip()
+ return _normalize_cxx_type_str(f"{base}*")
+ if "*" in cxx_param:
+ return cxx_param
+ return _normalize_cxx_type_str(f"{cxx_param}*")
+
+
+def _normalize_numba_arg_type(arg: nbtypes.Type) -> nbtypes.Type:
+ if isinstance(arg, nbtypes.Literal):
+ literal_type = getattr(arg, "literal_type", None)
+ if literal_type is not None:
+ return literal_type
+ return arg
+
+
+def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None:
+ arg = _normalize_numba_arg_type(arg)
+ try:
+ return _normalize_cxx_type_str(to_c_type_str(arg))
+ except ValueError:
+ return None
+
+
+def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
+ """Best-effort compatibility check for non-templated parameters."""
+ nb_expected = to_numba_type(cxx_type)
+ if nb_expected is nbtypes.undefined:
+ return True
+ return nb_expected == _normalize_numba_arg_type(arg)
+
+
+def _deduce_from_type_pattern(
+ cxx_type: str, arg_cxx: str, placeholders: Iterable[str]
+) -> dict[str, str] | None:
+ """Deduce template placeholder values from a C++ type pattern.
+
+ This helper treats ``cxx_type`` as a pattern that may contain one or more
+ placeholder names (e.g., ``T``) supplied in ``placeholders``. It scans
+ ``cxx_type`` left-to-right, replacing each placeholder occurrence with a
+ non-greedy capture group, and records the placeholder name for every
+ occurrence. The resulting regex is matched against ``arg_cxx``. If the
+ match fails, no deduction is possible and ``None`` is returned.
+
+ For each captured group, the value is stripped of whitespace. Empty captures
+ are rejected. When a placeholder appears multiple times, all occurrences
+ must resolve to the same concrete type; conflicting values cause this
+ function to return ``None``. On success, a mapping from placeholder name to
+ deduced type string is returned.
+
+ Example:
+ # Pattern from a templated parameter.
+ cxx_type = "Pair"
+ arg_cxx = "Pair"
+ placeholders = ["T", "U"]
+
+ # Deduce placeholder values from the concrete argument type.
+ deduced = _deduce_from_type_pattern(cxx_type, arg_cxx, placeholders)
+ # deduced == {"T": "float", "U": "int"}
+
+ # The same placeholder can appear multiple times; all captures must agree.
+ repeat = _deduce_from_type_pattern("Pair", "Pair", ["T"])
+ # repeat == {"T": "int"}
+ """
+ placeholders_in_type = [p for p in placeholders if p in cxx_type]
+ if not placeholders_in_type:
+ return None
+
+ unique_placeholders = list(dict.fromkeys(placeholders_in_type))
+ placeholder_patterns = sorted(unique_placeholders, key=len, reverse=True)
+ placeholder_regex = re.compile(
+ "|".join(re.escape(ph) for ph in placeholder_patterns)
+ )
+
+ order: list[str] = []
+ pattern_parts: list[str] = []
+ last_index = 0
+ for ph_match in placeholder_regex.finditer(cxx_type):
+ start, end = ph_match.span()
+ if start > last_index:
+ pattern_parts.append(re.escape(cxx_type[last_index:start]))
+ pattern_parts.append(r"(.*?)")
+ order.append(ph_match.group(0))
+ last_index = end
+ pattern_parts.append(re.escape(cxx_type[last_index:]))
+ pattern = "".join(pattern_parts)
+
+ match = re.fullmatch(pattern, arg_cxx)
+ if not match:
+ return None
+
+ deduced: dict[str, str] = {}
+ for ph, value in zip(order, match.groups()):
+ value = value.strip()
+ if not value:
+ return None
+ if ph in deduced and deduced[ph] != value:
+ return None
+ deduced[ph] = value
+ return deduced
+
+
+def _replace_placeholders(type_str: str, replacements: dict[str, str]) -> str:
+ for key, value in replacements.items():
+ type_str = type_str.replace(key, value)
+ return type_str
+
+
+def _specialize_type(
+ ast_type: pylibastcanopy.Type, replacements: dict[str, str]
+) -> pylibastcanopy.Type:
+ new_name = _replace_placeholders(ast_type.name, replacements)
+ new_unqualified = _replace_placeholders(
+ ast_type.unqualified_non_ref_type_name, replacements
+ )
+ return pylibastcanopy.Type(
+ new_name,
+ new_unqualified,
+ ast_type.is_right_reference(),
+ ast_type.is_left_reference(),
+ )
+
+
+def _specialize_function(
+ func: Function, replacements: dict[str, str]
+) -> Function:
+ new_return = _specialize_type(func.return_type, replacements)
+ new_params = [
+ pylibastcanopy.ParamVar(p.name, _specialize_type(p.type_, replacements))
+ for p in func.params
+ ]
+
+ if isinstance(func, StructMethod):
+ return StructMethod(
+ func.name,
+ func.qual_name,
+ new_return,
+ new_params,
+ func.kind,
+ func.exec_space,
+ func.is_constexpr,
+ func.is_move_constructor,
+ func.mangled_name,
+ func.attributes,
+ func.parse_entry_point,
+ )
+
+ return Function(
+ func.name,
+ func.qual_name,
+ new_return,
+ new_params,
+ func.exec_space,
+ func.is_constexpr,
+ func.mangled_name,
+ func.attributes,
+ func.parse_entry_point,
+ )
+
+
+def _clone_function_template(
+ templ: FunctionTemplate, func: Function
+) -> FunctionTemplate:
+ return FunctionTemplate(
+ templ.template_parameters,
+ templ.num_min_required_args,
+ func,
+ templ.qual_name,
+ templ.parse_entry_point,
+ )
+
+
+def _unresolved_placeholders(
+ func: Function, placeholder_names: Iterable[str]
+) -> bool:
+ placeholders = tuple(placeholder_names)
+ if any(
+ p in func.return_type.unqualified_non_ref_type_name
+ for p in placeholders
+ ):
+ return True
+ for param in func.params:
+ if any(
+ p in param.type_.unqualified_non_ref_type_name for p in placeholders
+ ):
+ return True
+ return False
+
+
+def deduce_templated_overloads(
+ *,
+ qualname: str,
+ overloads: list[FunctionTemplate],
+ args: tuple[nbtypes.Type, ...],
+ overrides: dict | None = None,
+ debug: bool | None = None,
+) -> tuple[list[FunctionTemplate], list[Exception]]:
+ """
+ Perform template argument deduction for templated method overloads.
+
+ Returns a list of FunctionTemplate objects with fully-specialized
+ Function/Method types, plus any arg_intent-related errors encountered
+ while computing visible arity.
+
+ Enable debug output by passing debug=True or setting the
+ NUMBAST_TAD_DEBUG=1 environment variable.
+ """
+ specialized: list[FunctionTemplate] = []
+ intent_errors: list[Exception] = []
+
+ _debug_print(
+ debug,
+ f"begin: {qualname}, overloads={len(overloads)}, args={len(args)}, "
+ f"overrides={'yes' if overrides else 'no'}",
+ )
+
+ for idx, templ in enumerate(overloads):
+ _debug_print(
+ debug,
+ f"overload[{idx}] {templ.function.name}: "
+ f"params={[p.type_.unqualified_non_ref_type_name for p in templ.function.params]}",
+ )
+ if overrides is None:
+ visible_param_indices = tuple(range(len(templ.function.params)))
+ pass_ptr_mask = tuple(False for _ in visible_param_indices)
+ else:
+ try:
+ plan = compute_intent_plan(
+ params=templ.function.params,
+ param_types=templ.function.param_types,
+ overrides=overrides,
+ allow_out_return=True,
+ )
+ except (ValueError, TypeError) as exc:
+ intent_errors.append(exc)
+ _debug_print(
+ debug,
+ f" intent plan error: {exc}",
+ )
+ continue
+ visible_param_indices = plan.visible_param_indices
+ pass_ptr_mask = plan.pass_ptr_mask
+ _debug_print(
+ debug,
+ " intent plan: "
+ f"visible={plan.visible_param_indices}, "
+ f"out_return={plan.out_return_indices}, "
+ f"pass_ptr={plan.pass_ptr_mask}",
+ )
+
+ if len(visible_param_indices) != len(args):
+ _debug_print(
+ debug,
+ " skip: visible arity mismatch "
+ f"(visible={len(visible_param_indices)} vs args={len(args)})",
+ )
+ continue
+
+ placeholder_names = [
+ tp.type_.name
+ for tp in templ.template_parameters
+ if tp.kind == pylibastcanopy.template_param_kind.type_
+ ]
+ _debug_print(
+ debug,
+ f" placeholders={placeholder_names or 'none'}",
+ )
+ mapping: dict[str, str] = {}
+ failed = False
+
+ for vis_pos, (arg, param_idx) in enumerate(
+ zip(args, visible_param_indices)
+ ):
+ param = templ.function.params[param_idx]
+ cxx_param = _normalize_cxx_type_str(
+ param.type_.unqualified_non_ref_type_name
+ )
+ pass_ptr = bool(pass_ptr_mask[vis_pos])
+ cxx_param = _apply_pass_ptr(cxx_param, pass_ptr)
+ arg_cxx = _numba_arg_to_cxx_type(arg)
+ if arg_cxx is None:
+ _debug_print(
+ debug,
+ f" arg[{param_idx}] {param.name}: "
+ f"numba={arg} could not map to C++ type",
+ )
+ failed = True
+ break
+
+ _debug_print(
+ debug,
+ f" arg[{param_idx}] {param.name}: "
+ f"param={cxx_param}, arg={arg_cxx}, pass_ptr={pass_ptr}",
+ )
+
+ deduced = _deduce_from_type_pattern(
+ cxx_param, arg_cxx, placeholder_names
+ )
+ if deduced is None:
+ if not _param_type_matches_arg(cxx_param, arg):
+ _debug_print(
+ debug,
+ f" mismatch: param={cxx_param}, arg={arg}",
+ )
+ failed = True
+ break
+ _debug_print(
+ debug,
+ " no deduction needed; param matches arg",
+ )
+ continue
+
+ _debug_print(
+ debug,
+ f" deduced={deduced}",
+ )
+
+ for key, value in deduced.items():
+ prev = mapping.get(key)
+ if prev is not None and prev != value:
+ _debug_print(
+ debug,
+ f" conflict: {key}={prev} vs {value}",
+ )
+ failed = True
+ break
+ mapping[key] = value
+ if failed:
+ break
+
+ if failed:
+ _debug_print(debug, " skip: deduction failed")
+ continue
+
+ _debug_print(debug, f" mapping={mapping}")
+ specialized_func = _specialize_function(templ.function, mapping)
+ if _unresolved_placeholders(specialized_func, placeholder_names):
+ _debug_print(
+ debug, " skip: unresolved placeholders after specialize"
+ )
+ continue
+
+ _debug_print(
+ debug,
+ " specialized: "
+ f"return={specialized_func.return_type.unqualified_non_ref_type_name}, "
+ f"params={[p.type_.unqualified_non_ref_type_name for p in specialized_func.params]}",
+ )
+ specialized.append(_clone_function_template(templ, specialized_func))
+
+ _debug_print(
+ debug,
+ f"end: {qualname}, specialized={len(specialized)}, intent_errors={len(intent_errors)}",
+ )
+ return specialized, intent_errors
diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py
index 53992597..e0f9c876 100644
--- a/numbast/src/numbast/types.py
+++ b/numbast/src/numbast/types.py
@@ -126,7 +126,9 @@ def to_numba_type(ty: str):
base_ty, size = is_array_type.groups()
return nbtypes.UniTuple(to_numba_type(base_ty), int(size))
- return CTYPE_MAPS[ty]
+ # FIXME: Currently returning an undefined type. But in a future PR, we will
+ # return an opaque type instead.
+ return CTYPE_MAPS.get(ty, nbtypes.undefined)
def to_numba_arg_type(ast_type) -> nbtypes.Type:
@@ -142,6 +144,8 @@ def to_numba_arg_type(ast_type) -> nbtypes.Type:
def to_c_type_str(nbty: nbtypes.Type) -> str:
+ if isinstance(nbty, nbtypes.CPointer):
+ return f"{to_c_type_str(nbty.dtype)}*"
if nbty not in NUMBA_TO_CTYPE_MAPS:
raise ValueError(
f"Unknown numba type attempted to converted into ctype: {nbty}"
diff --git a/numbast/tests/test_deduction.py b/numbast/tests/test_deduction.py
new file mode 100644
index 00000000..db69cb49
--- /dev/null
+++ b/numbast/tests/test_deduction.py
@@ -0,0 +1,291 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import textwrap
+
+import pytest
+
+from ast_canopy import parse_declarations_from_source
+from numba.cuda import types as nbtypes
+
+from numbast.deduction import (
+ _deduce_from_type_pattern,
+ deduce_templated_overloads,
+)
+
+
+_CXX_SOURCE = textwrap.dedent(
+ """\
+ #pragma once
+
+ template
+ __device__ T add(T a, T b) { return a + b; }
+
+ template
+ __device__ T add(T a, T b, T c) { return a + b + c; }
+
+ template
+ __device__ T add_int(int a, T b) { return a + b; }
+
+ template
+ __device__ void store_ptr(T *out, T value) { *out = value; }
+
+ template
+ __device__ void store_ref(T &out, T value) { out = value; }
+
+ template
+ __device__ T return_only();
+
+ template
+ __device__ void bad_out(T value) { (void)value; }
+
+ struct Box {
+ template
+ __device__ T mul(T a, T b) const { return a * b; }
+
+ template
+ __device__ void write(T &out, T value) const { out = value; }
+ };
+
+ template
+ __device__ void fill_array(T (&arr)[4], T value) {
+ for (int i = 0; i < 4; ++i) arr[i] = value;
+ }
+ """
+)
+
+
+@pytest.fixture(scope="module")
+def deduction_decls(tmp_path_factory):
+ tmp_dir = tmp_path_factory.mktemp("deduction")
+ header_path = tmp_dir / "deduction_sample.cuh"
+ header_path.write_text(_CXX_SOURCE, encoding="utf-8")
+ decls = parse_declarations_from_source(
+ str(header_path),
+ [str(header_path)],
+ "sm_80",
+ verbose=False,
+ )
+ return decls
+
+
+def _get_function_templates(decls, name: str):
+ templs = [
+ templ
+ for templ in decls.function_templates
+ if templ.function.name == name
+ ]
+ if not templs:
+ raise AssertionError(f"No function templates found for {name!r}")
+ return templs
+
+
+def _get_struct_method_templates(decls, struct_name: str, method_name: str):
+ for struct in decls.structs:
+ if struct.name == struct_name:
+ templs = [
+ templ
+ for templ in struct.templated_methods
+ if templ.function.name == method_name
+ ]
+ if not templs:
+ raise AssertionError(
+ f"No templated methods found for {struct_name}::{method_name}"
+ )
+ return templs
+ raise AssertionError(f"Struct {struct_name!r} not found")
+
+
+def test_overload_arity_selection(deduction_decls):
+ """Select the overload that matches the visible argument arity."""
+ overloads = _get_function_templates(deduction_decls, "add")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="add",
+ overloads=overloads,
+ args=(nbtypes.int32, nbtypes.int32),
+ overrides=None,
+ )
+
+ assert intent_errors == []
+ assert len(specialized) == 1
+ func = specialized[0].function
+ assert len(func.params) == 2
+ assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+ "int",
+ "int",
+ ]
+ assert func.return_type.unqualified_non_ref_type_name == "int"
+
+
+def test_conflicting_deduction_skips_overload(deduction_decls):
+ """Skip overloads when template placeholders deduce conflicting types."""
+ overloads = _get_function_templates(deduction_decls, "add")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="add",
+ overloads=overloads,
+ args=(nbtypes.int32, nbtypes.float32),
+ overrides=None,
+ )
+
+ assert intent_errors == []
+ assert specialized == []
+
+
+def test_repeated_placeholder_conflict_in_type_pattern():
+ """Reject types when repeated placeholders deduce conflicting values."""
+ deduced = _deduce_from_type_pattern(
+ "pair",
+ "pair",
+ ["T"],
+ )
+ assert deduced is None
+
+
+def test_non_templated_param_requires_match(deduction_decls):
+ """Require exact matches for non-templated parameters."""
+ overloads = _get_function_templates(deduction_decls, "add_int")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="add_int",
+ overloads=overloads,
+ args=(nbtypes.int32, nbtypes.float32),
+ overrides=None,
+ )
+
+ assert intent_errors == []
+ assert len(specialized) == 1
+ func = specialized[0].function
+ assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+ "int",
+ "float",
+ ]
+ assert func.return_type.unqualified_non_ref_type_name == "float"
+
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="add_int",
+ overloads=overloads,
+ args=(nbtypes.float32, nbtypes.float32),
+ overrides=None,
+ )
+
+ assert intent_errors == []
+ assert specialized == []
+
+
+def test_return_only_placeholder_skipped(deduction_decls):
+ """Skip overloads with unresolved placeholders only in return type."""
+ overloads = _get_function_templates(deduction_decls, "return_only")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="return_only",
+ overloads=overloads,
+ args=(),
+ overrides=None,
+ )
+
+ assert intent_errors == []
+ assert specialized == []
+
+
+@pytest.mark.parametrize(
+ ("overrides", "args"),
+ [
+ ({"out": "out_ptr"}, (nbtypes.CPointer(nbtypes.int32), nbtypes.int32)),
+ (
+ {"out": "inout_ptr"},
+ (nbtypes.CPointer(nbtypes.int32), nbtypes.int32),
+ ),
+ ({"out": "out_return"}, (nbtypes.int32,)),
+ ],
+)
+def test_override_intents_adjust_deduction(deduction_decls, overrides, args):
+ """Allow overrides (including out_return) to adjust deduction behavior."""
+ overloads = _get_function_templates(deduction_decls, "store_ref")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="store_ref",
+ overloads=overloads,
+ args=args,
+ overrides=overrides,
+ )
+
+ assert intent_errors == []
+ assert len(specialized) == 1
+ func = specialized[0].function
+ assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+ "int",
+ "int",
+ ]
+ assert func.params[0].type_.is_left_reference()
+ assert not func.params[1].type_.is_left_reference()
+
+
+def test_invalid_override_surfaces_error(deduction_decls):
+ """Return intent errors when overrides are incompatible with param types."""
+ overloads = _get_function_templates(deduction_decls, "bad_out")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="bad_out",
+ overloads=overloads,
+ args=(nbtypes.int32,),
+ overrides={"value": "out_ptr"},
+ )
+
+ assert specialized == []
+ assert intent_errors
+ err = intent_errors[0]
+ assert isinstance(err, ValueError)
+ assert "reference parameters" in str(err)
+
+
+def test_struct_method_specialization(deduction_decls):
+ """Specialize templated struct methods into concrete types."""
+ overloads = _get_struct_method_templates(deduction_decls, "Box", "mul")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="Box.mul",
+ overloads=overloads,
+ args=(nbtypes.float32, nbtypes.float32),
+ overrides=None,
+ )
+
+ assert intent_errors == []
+ assert len(specialized) == 1
+ func = specialized[0].function
+ assert func.name == "mul"
+ assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+ "float",
+ "float",
+ ]
+ assert func.return_type.unqualified_non_ref_type_name == "float"
+
+
+def test_unmappable_numba_arg_skips_overload(deduction_decls):
+ """Skip overloads when Numba args cannot map to C++ types."""
+ overloads = _get_function_templates(deduction_decls, "add")
+ unmappable = nbtypes.float32[:]
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="add",
+ overloads=overloads,
+ args=(unmappable, unmappable),
+ overrides=None,
+ )
+
+ assert intent_errors == []
+ assert specialized == []
+
+
+def test_array_param_decays_to_pointer_with_pass_ptr(deduction_decls):
+ """Array types like T[N] decay to T* when pass_ptr is applied."""
+ overloads = _get_function_templates(deduction_decls, "fill_array")
+ specialized, intent_errors = deduce_templated_overloads(
+ qualname="fill_array",
+ overloads=overloads,
+ args=(nbtypes.CPointer(nbtypes.int32), nbtypes.int32),
+ overrides={"arr": "out_ptr"},
+ )
+
+ assert intent_errors == []
+ assert len(specialized) == 1
+ func = specialized[0].function
+ # After specialization, params should be int[4] (reference) and int
+ assert func.params[0].type_.unqualified_non_ref_type_name == "int[4]"
+ assert func.params[0].type_.is_left_reference()
+ assert func.params[1].type_.unqualified_non_ref_type_name == "int"