diff --git a/ci/check_style.sh b/ci/check_style.sh
index b6ab5ed0..d57071b7 100755
--- a/ci/check_style.sh
+++ b/ci/check_style.sh
@@ -4,6 +4,10 @@
 
 set -euo pipefail
 
+# Install pre-commit requirements
+apt-get update
+apt-get install -y g++
+
 pip install pre-commit
 
 # Run pre-commit checks
diff --git a/numbast/src/numbast/deduction.py b/numbast/src/numbast/deduction.py
new file mode 100644
index 00000000..3d231b22
--- /dev/null
+++ b/numbast/src/numbast/deduction.py
@@ -0,0 +1,447 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Iterable
+import os
+import re
+
+from ast_canopy import pylibastcanopy
+from ast_canopy.decl import Function, FunctionTemplate, StructMethod
+
+from numba.cuda import types as nbtypes
+
+from numbast.intent import compute_intent_plan
+from numbast.types import to_c_type_str, to_numba_type
+
+
+_SPACE_RE = re.compile(r"\s+")
+_ARRAY_SUFFIX_RE = re.compile(r"^(?P<base>.+?)\s*\[(?P<size>[^\]]*)\]\s*$")
+_DEBUG_ENV_VAR = "NUMBAST_TAD_DEBUG"
+
+
+def _debug_enabled(debug: bool | None) -> bool:
+    """Return ``debug`` if set, else the NUMBAST_TAD_DEBUG env-var setting."""
+    if debug is not None:
+        return debug
+    value = os.environ.get(_DEBUG_ENV_VAR, "")
+    return value.strip().lower() in ("1", "true", "yes", "on")
+
+
+def _debug_print(debug: bool | None, msg: str) -> None:
+    """Print ``msg`` with a module prefix when debugging is enabled."""
+    if _debug_enabled(debug):
+        print(f"[numbast.deduction] {msg}")
+
+
+def _normalize_cxx_type_str(type_str: str) -> str:
+    """Normalize C++ type strings for comparison."""
+    if not type_str:
+        return type_str
+    type_str = type_str.strip()
+    type_str = _SPACE_RE.sub(" ", type_str)
+    # Remove whitespace around pointer/reference markers.
+    type_str = type_str.replace(" *", "*").replace("* ", "*")
+    type_str = type_str.replace(" &", "&").replace("& ", "&")
+    return type_str
+
+
+def _apply_pass_ptr(cxx_param: str, pass_ptr: bool) -> str:
+    """Return the pointer form of ``cxx_param`` when ``pass_ptr`` is set."""
+    if not pass_ptr:
+        return cxx_param
+    array_match = _ARRAY_SUFFIX_RE.match(cxx_param)
+    if array_match:
+        # Arrays decay to element pointers when passed by value.
+        base = array_match.group("base").strip()
+        return _normalize_cxx_type_str(f"{base}*")
+    if "*" in cxx_param:
+        return cxx_param
+    return _normalize_cxx_type_str(f"{cxx_param}*")
+
+
+def _normalize_numba_arg_type(arg: nbtypes.Type) -> nbtypes.Type:
+    """Unwrap a Numba ``Literal`` to its ``literal_type`` when available."""
+    if isinstance(arg, nbtypes.Literal):
+        literal_type = getattr(arg, "literal_type", None)
+        if literal_type is not None:
+            return literal_type
+    return arg
+
+
+def _numba_arg_to_cxx_type(arg: nbtypes.Type) -> str | None:
+    """Map a Numba type to a normalized C++ type string, or ``None``."""
+    arg = _normalize_numba_arg_type(arg)
+    try:
+        return _normalize_cxx_type_str(to_c_type_str(arg))
+    except ValueError:
+        return None
+
+
+def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
+    """Best-effort compatibility check for non-templated parameters."""
+    nb_expected = to_numba_type(cxx_type)
+    if nb_expected is nbtypes.undefined:
+        return True
+    return nb_expected == _normalize_numba_arg_type(arg)
+
+
+def _compile_placeholder_regex(names: Iterable[str]) -> re.Pattern[str] | None:
+    """Compile a regex matching any of ``names`` as a whole identifier.
+
+    Identifier-boundary lookarounds keep a placeholder such as ``T`` from
+    matching inside a longer identifier such as ``Tensor``. Empty names are
+    ignored; ``None`` is returned when no usable name remains.
+    """
+    unique = list(dict.fromkeys(name for name in names if name))
+    if not unique:
+        return None
+    # Longest first so a shorter name never shadows a longer alternative.
+    unique.sort(key=len, reverse=True)
+    alternatives = "|".join(re.escape(name) for name in unique)
+    return re.compile(rf"(?<![A-Za-z0-9_])(?:{alternatives})(?![A-Za-z0-9_])")
+
+
+def _deduce_from_type_pattern(
+    cxx_type: str, arg_cxx: str, placeholders: Iterable[str]
+) -> dict[str, str] | None:
+    """Deduce template placeholder values from a C++ type pattern.
+
+    This helper treats ``cxx_type`` as a pattern that may contain one or more
+    placeholder names (e.g., ``T``) supplied in ``placeholders``. It scans
+    ``cxx_type`` left-to-right, replacing each placeholder occurrence with a
+    non-greedy capture group, and records the placeholder name for every
+    occurrence. The resulting regex is matched against ``arg_cxx``. If the
+    match fails, no deduction is possible and ``None`` is returned.
+
+    Placeholders match only as whole identifiers, so ``T`` is not found
+    inside ``Tensor``.
+
+    For each captured group, the value is stripped of whitespace. Empty captures
+    are rejected. When a placeholder appears multiple times, all occurrences
+    must resolve to the same concrete type; conflicting values cause this
+    function to return ``None``. On success, a mapping from placeholder name to
+    deduced type string is returned.
+
+    Example:
+        # Pattern from a templated parameter.
+        cxx_type = "Pair<T, U>"
+        arg_cxx = "Pair<float, int>"
+        placeholders = ["T", "U"]
+
+        # Deduce placeholder values from the concrete argument type.
+        deduced = _deduce_from_type_pattern(cxx_type, arg_cxx, placeholders)
+        # deduced == {"T": "float", "U": "int"}
+
+        # The same placeholder can appear multiple times; all captures must agree.
+        repeat = _deduce_from_type_pattern("Pair<T, T>", "Pair<int, int>", ["T"])
+        # repeat == {"T": "int"}
+    """
+    placeholder_regex = _compile_placeholder_regex(placeholders)
+    if placeholder_regex is None:
+        return None
+
+    order: list[str] = []
+    pattern_parts: list[str] = []
+    last_index = 0
+    for ph_match in placeholder_regex.finditer(cxx_type):
+        start, end = ph_match.span()
+        if start > last_index:
+            pattern_parts.append(re.escape(cxx_type[last_index:start]))
+        pattern_parts.append(r"(.*?)")
+        order.append(ph_match.group(0))
+        last_index = end
+    if not order:
+        # No placeholder occurs in this parameter type: nothing to deduce.
+        return None
+    pattern_parts.append(re.escape(cxx_type[last_index:]))
+    pattern = "".join(pattern_parts)
+
+    match = re.fullmatch(pattern, arg_cxx)
+    if not match:
+        return None
+
+    deduced: dict[str, str] = {}
+    for ph, value in zip(order, match.groups()):
+        value = value.strip()
+        if not value:
+            return None
+        if ph in deduced and deduced[ph] != value:
+            return None
+        deduced[ph] = value
+    return deduced
+
+
+def _replace_placeholders(type_str: str, replacements: dict[str, str]) -> str:
+    """Substitute whole-identifier placeholder names in ``type_str``.
+
+    All names are replaced in a single pass, so a substituted value is never
+    rescanned for further placeholders.
+    """
+    placeholder_regex = _compile_placeholder_regex(replacements)
+    if placeholder_regex is None:
+        return type_str
+    return placeholder_regex.sub(lambda m: replacements[m.group(0)], type_str)
+
+
+def _specialize_type(
+    ast_type: pylibastcanopy.Type, replacements: dict[str, str]
+) -> pylibastcanopy.Type:
+    """Return a copy of ``ast_type`` with placeholders substituted."""
+    new_name = _replace_placeholders(ast_type.name, replacements)
+    new_unqualified = _replace_placeholders(
+        ast_type.unqualified_non_ref_type_name, replacements
+    )
+    return pylibastcanopy.Type(
+        new_name,
+        new_unqualified,
+        ast_type.is_right_reference(),
+        ast_type.is_left_reference(),
+    )
+
+
+def _specialize_function(
+    func: Function, replacements: dict[str, str]
+) -> Function:
+    """Return a copy of ``func`` with placeholders substituted in its types."""
+    new_return = _specialize_type(func.return_type, replacements)
+    new_params = [
+        pylibastcanopy.ParamVar(p.name, _specialize_type(p.type_, replacements))
+        for p in func.params
+    ]
+
+    if isinstance(func, StructMethod):
+        return StructMethod(
+            func.name,
+            func.qual_name,
+            new_return,
+            new_params,
+            func.kind,
+            func.exec_space,
+            func.is_constexpr,
+            func.is_move_constructor,
+            func.mangled_name,
+            func.attributes,
+            func.parse_entry_point,
+        )
+
+    return Function(
+        func.name,
+        func.qual_name,
+        new_return,
+        new_params,
+        func.exec_space,
+        func.is_constexpr,
+        func.mangled_name,
+        func.attributes,
+        func.parse_entry_point,
+    )
+
+
+def _clone_function_template(
+    templ: FunctionTemplate, func: Function
+) -> FunctionTemplate:
+    """Rebuild ``templ`` around ``func``, keeping its other attributes."""
+    return FunctionTemplate(
+        templ.template_parameters,
+        templ.num_min_required_args,
+        func,
+        templ.qual_name,
+        templ.parse_entry_point,
+    )
+
+
+def _unresolved_placeholders(
+    func: Function, placeholder_names: Iterable[str]
+) -> bool:
+    """Return True if any placeholder still names a type in ``func``.
+
+    The return type and every parameter type are checked for whole-identifier
+    occurrences of the placeholder names.
+    """
+    placeholder_regex = _compile_placeholder_regex(placeholder_names)
+    if placeholder_regex is None:
+        return False
+    type_names = [func.return_type.unqualified_non_ref_type_name]
+    type_names += [p.type_.unqualified_non_ref_type_name for p in func.params]
+    return any(placeholder_regex.search(name) for name in type_names)
+
+
+def deduce_templated_overloads(
+    *,
+    qualname: str,
+    overloads: list[FunctionTemplate],
+    args: tuple[nbtypes.Type, ...],
+    overrides: dict | None = None,
+    debug: bool | None = None,
+) -> tuple[list[FunctionTemplate], list[Exception]]:
+    """
+    Perform template argument deduction for templated method overloads.
+
+    Returns a list of FunctionTemplate objects with fully-specialized
+    Function/Method types, plus any arg_intent-related errors encountered
+    while computing visible arity.
+
+    Enable debug output by passing debug=True or setting the
+    NUMBAST_TAD_DEBUG=1 environment variable.
+
+    Args:
+        qualname: Name of the callable; used only in debug output.
+        overloads: Candidate function templates to try.
+        args: Numba types of the call arguments.
+        overrides: Optional arg-intent overrides passed to
+            ``compute_intent_plan``. With ``None`` every parameter is visible
+            and no pointer passing is applied.
+        debug: Force debug output on or off; ``None`` defers to the
+            environment variable.
+    """
+    specialized: list[FunctionTemplate] = []
+    intent_errors: list[Exception] = []
+
+    _debug_print(
+        debug,
+        f"begin: {qualname}, overloads={len(overloads)}, args={len(args)}, "
+        f"overrides={'yes' if overrides else 'no'}",
+    )
+
+    for idx, templ in enumerate(overloads):
+        _debug_print(
+            debug,
+            f"overload[{idx}] {templ.function.name}: "
+            f"params={[p.type_.unqualified_non_ref_type_name for p in templ.function.params]}",
+        )
+        if overrides is None:
+            visible_param_indices = tuple(range(len(templ.function.params)))
+            pass_ptr_mask = tuple(False for _ in visible_param_indices)
+        else:
+            try:
+                plan = compute_intent_plan(
+                    params=templ.function.params,
+                    param_types=templ.function.param_types,
+                    overrides=overrides,
+                    allow_out_return=True,
+                )
+            except (ValueError, TypeError) as exc:
+                intent_errors.append(exc)
+                _debug_print(
+                    debug,
+                    f" intent plan error: {exc}",
+                )
+                continue
+            visible_param_indices = plan.visible_param_indices
+            pass_ptr_mask = plan.pass_ptr_mask
+            _debug_print(
+                debug,
+                " intent plan: "
+                f"visible={plan.visible_param_indices}, "
+                f"out_return={plan.out_return_indices}, "
+                f"pass_ptr={plan.pass_ptr_mask}",
+            )
+
+        if len(visible_param_indices) != len(args):
+            _debug_print(
+                debug,
+                " skip: visible arity mismatch "
+                f"(visible={len(visible_param_indices)} vs args={len(args)})",
+            )
+            continue
+
+        placeholder_names = [
+            tp.type_.name
+            for tp in templ.template_parameters
+            if tp.kind == pylibastcanopy.template_param_kind.type_
+        ]
+        _debug_print(
+            debug,
+            f" placeholders={placeholder_names or 'none'}",
+        )
+        mapping: dict[str, str] = {}
+        failed = False
+
+        for vis_pos, (arg, param_idx) in enumerate(
+            zip(args, visible_param_indices)
+        ):
+            param = templ.function.params[param_idx]
+            cxx_param = _normalize_cxx_type_str(
+                param.type_.unqualified_non_ref_type_name
+            )
+            pass_ptr = bool(pass_ptr_mask[vis_pos])
+            cxx_param = _apply_pass_ptr(cxx_param, pass_ptr)
+            arg_cxx = _numba_arg_to_cxx_type(arg)
+            if arg_cxx is None:
+                _debug_print(
+                    debug,
+                    f" arg[{param_idx}] {param.name}: "
+                    f"numba={arg} could not map to C++ type",
+                )
+                failed = True
+                break
+
+            _debug_print(
+                debug,
+                f" arg[{param_idx}] {param.name}: "
+                f"param={cxx_param}, arg={arg_cxx}, pass_ptr={pass_ptr}",
+            )
+
+            deduced = _deduce_from_type_pattern(
+                cxx_param, arg_cxx, placeholder_names
+            )
+            if deduced is None:
+                if not _param_type_matches_arg(cxx_param, arg):
+                    _debug_print(
+                        debug,
+                        f" mismatch: param={cxx_param}, arg={arg}",
+                    )
+                    failed = True
+                    break
+                _debug_print(
+                    debug,
+                    " no deduction needed; param matches arg",
+                )
+                continue
+
+            _debug_print(
+                debug,
+                f" deduced={deduced}",
+            )
+
+            for key, value in deduced.items():
+                prev = mapping.get(key)
+                if prev is not None and prev != value:
+                    _debug_print(
+                        debug,
+                        f" conflict: {key}={prev} vs {value}",
+                    )
+                    failed = True
+                    break
+                mapping[key] = value
+            if failed:
+                break
+
+        if failed:
+            _debug_print(debug, " skip: deduction failed")
+            continue
+
+        _debug_print(debug, f" mapping={mapping}")
+        specialized_func = _specialize_function(templ.function, mapping)
+        if _unresolved_placeholders(specialized_func, placeholder_names):
+            _debug_print(
+                debug, " skip: unresolved placeholders after specialize"
+            )
+            continue
+
+        _debug_print(
+            debug,
+            " specialized: "
+            f"return={specialized_func.return_type.unqualified_non_ref_type_name}, "
+            f"params={[p.type_.unqualified_non_ref_type_name for p in specialized_func.params]}",
+        )
+        specialized.append(_clone_function_template(templ, specialized_func))
+
+    _debug_print(
+        debug,
+        f"end: {qualname}, specialized={len(specialized)}, intent_errors={len(intent_errors)}",
+    )
+    return specialized, intent_errors
diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py
index 53992597..e0f9c876 100644
--- a/numbast/src/numbast/types.py
+++ b/numbast/src/numbast/types.py
@@ -126,7 +126,9 @@ def to_numba_type(ty: str):
         base_ty, size = is_array_type.groups()
         return nbtypes.UniTuple(to_numba_type(base_ty), int(size))
 
-    return CTYPE_MAPS[ty]
+    # FIXME: Currently returning an undefined type. But in a future PR, we will
+    # return an opaque type instead.
+    return CTYPE_MAPS.get(ty, nbtypes.undefined)
 
 
 def to_numba_arg_type(ast_type) -> nbtypes.Type:
@@ -142,6 +144,8 @@ def to_numba_arg_type(ast_type) -> nbtypes.Type:
 
 
 def to_c_type_str(nbty: nbtypes.Type) -> str:
+    if isinstance(nbty, nbtypes.CPointer):
+        return f"{to_c_type_str(nbty.dtype)}*"
     if nbty not in NUMBA_TO_CTYPE_MAPS:
         raise ValueError(
             f"Unknown numba type attempted to converted into ctype: {nbty}"
diff --git a/numbast/tests/test_deduction.py b/numbast/tests/test_deduction.py
new file mode 100644
index 00000000..db69cb49
--- /dev/null
+++ b/numbast/tests/test_deduction.py
@@ -0,0 +1,291 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import textwrap
+
+import pytest
+
+from ast_canopy import parse_declarations_from_source
+from numba.cuda import types as nbtypes
+
+from numbast.deduction import (
+    _deduce_from_type_pattern,
+    deduce_templated_overloads,
+)
+
+
+_CXX_SOURCE = textwrap.dedent(
+    """\
+    #pragma once
+
+    template <typename T>
+    __device__ T add(T a, T b) { return a + b; }
+
+    template <typename T>
+    __device__ T add(T a, T b, T c) { return a + b + c; }
+
+    template <typename T>
+    __device__ T add_int(int a, T b) { return a + b; }
+
+    template <typename T>
+    __device__ void store_ptr(T *out, T value) { *out = value; }
+
+    template <typename T>
+    __device__ void store_ref(T &out, T value) { out = value; }
+
+    template <typename T>
+    __device__ T return_only();
+
+    template <typename T>
+    __device__ void bad_out(T value) { (void)value; }
+
+    struct Box {
+        template <typename T>
+        __device__ T mul(T a, T b) const { return a * b; }
+
+        template <typename T>
+        __device__ void write(T &out, T value) const { out = value; }
+    };
+
+    template <typename T>
+    __device__ void fill_array(T (&arr)[4], T value) {
+        for (int i = 0; i < 4; ++i) arr[i] = value;
+    }
+    """
+)
+
+
+@pytest.fixture(scope="module")
+def deduction_decls(tmp_path_factory):
+    """Parse the sample CUDA header once per module and return its decls."""
+    tmp_dir = tmp_path_factory.mktemp("deduction")
+    header_path = tmp_dir / "deduction_sample.cuh"
+    header_path.write_text(_CXX_SOURCE, encoding="utf-8")
+    return parse_declarations_from_source(
+        str(header_path),
+        [str(header_path)],
+        "sm_80",
+        verbose=False,
+    )
+
+
+def _get_function_templates(decls, name: str):
+    """Return the function templates in ``decls`` named ``name``.
+
+    Raises AssertionError when none are found.
+    """
+    templs = [t for t in decls.function_templates if t.function.name == name]
+    if not templs:
+        raise AssertionError(f"No function templates found for {name!r}")
+    return templs
+
+
+def _get_struct_method_templates(decls, struct_name: str, method_name: str):
+    """Return ``struct_name``'s templated methods named ``method_name``."""
+    # Only the first struct whose name matches is inspected.
+    for struct in decls.structs:
+        if struct.name != struct_name:
+            continue
+        methods = struct.templated_methods
+        templs = [t for t in methods if t.function.name == method_name]
+        if not templs:
+            raise AssertionError(
+                f"No templated methods found for {struct_name}::{method_name}"
+            )
+        return templs
+    raise AssertionError(f"Struct {struct_name!r} not found")
+
+
+def test_overload_arity_selection(deduction_decls):
+    """Select the overload that matches the visible argument arity."""
+    overloads = _get_function_templates(deduction_decls, "add")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="add",
+        overloads=overloads,
+        args=(nbtypes.int32, nbtypes.int32),
+        overrides=None,
+    )
+
+    assert intent_errors == []
+    assert len(specialized) == 1
+    func = specialized[0].function
+    assert len(func.params) == 2
+    assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+        "int",
+        "int",
+    ]
+    assert func.return_type.unqualified_non_ref_type_name == "int"
+
+
+def test_conflicting_deduction_skips_overload(deduction_decls):
+    """Skip overloads when template placeholders deduce conflicting types."""
+    overloads = _get_function_templates(deduction_decls, "add")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="add",
+        overloads=overloads,
+        args=(nbtypes.int32, nbtypes.float32),
+        overrides=None,
+    )
+
+    assert intent_errors == []
+    assert specialized == []
+
+
+def test_repeated_placeholder_conflict_in_type_pattern():
+    """Reject types when repeated placeholders deduce conflicting values."""
+    deduced = _deduce_from_type_pattern(
+        "pair<T, T>",
+        "pair<int, float>",
+        ["T"],
+    )
+    assert deduced is None
+
+
+def test_non_templated_param_requires_match(deduction_decls):
+    """Require exact matches for non-templated parameters."""
+    overloads = _get_function_templates(deduction_decls, "add_int")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="add_int",
+        overloads=overloads,
+        args=(nbtypes.int32, nbtypes.float32),
+        overrides=None,
+    )
+
+    assert intent_errors == []
+    assert len(specialized) == 1
+    func = specialized[0].function
+    assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+        "int",
+        "float",
+    ]
+    assert func.return_type.unqualified_non_ref_type_name == "float"
+
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="add_int",
+        overloads=overloads,
+        args=(nbtypes.float32, nbtypes.float32),
+        overrides=None,
+    )
+
+    assert intent_errors == []
+    assert specialized == []
+
+
+def test_return_only_placeholder_skipped(deduction_decls):
+    """Skip overloads with unresolved placeholders only in return type."""
+    overloads = _get_function_templates(deduction_decls, "return_only")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="return_only",
+        overloads=overloads,
+        args=(),
+        overrides=None,
+    )
+
+    assert intent_errors == []
+    assert specialized == []
+
+
+@pytest.mark.parametrize(
+    ("overrides", "args"),
+    [
+        ({"out": "out_ptr"}, (nbtypes.CPointer(nbtypes.int32), nbtypes.int32)),
+        (
+            {"out": "inout_ptr"},
+            (nbtypes.CPointer(nbtypes.int32), nbtypes.int32),
+        ),
+        ({"out": "out_return"}, (nbtypes.int32,)),
+    ],
+)
+def test_override_intents_adjust_deduction(deduction_decls, overrides, args):
+    """Allow overrides (including out_return) to adjust deduction behavior."""
+    overloads = _get_function_templates(deduction_decls, "store_ref")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="store_ref",
+        overloads=overloads,
+        args=args,
+        overrides=overrides,
+    )
+
+    assert intent_errors == []
+    assert len(specialized) == 1
+    func = specialized[0].function
+    assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+        "int",
+        "int",
+    ]
+    assert func.params[0].type_.is_left_reference()
+    assert not func.params[1].type_.is_left_reference()
+
+
+def test_invalid_override_surfaces_error(deduction_decls):
+    """Return intent errors when overrides are incompatible with param types."""
+    overloads = _get_function_templates(deduction_decls, "bad_out")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="bad_out",
+        overloads=overloads,
+        args=(nbtypes.int32,),
+        overrides={"value": "out_ptr"},
+    )
+
+    assert specialized == []
+    assert intent_errors
+    err = intent_errors[0]
+    assert isinstance(err, ValueError)
+    assert "reference parameters" in str(err)
+
+
+def test_struct_method_specialization(deduction_decls):
+    """Specialize templated struct methods into concrete types."""
+    overloads = _get_struct_method_templates(deduction_decls, "Box", "mul")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="Box.mul",
+        overloads=overloads,
+        args=(nbtypes.float32, nbtypes.float32),
+        overrides=None,
+    )
+
+    assert intent_errors == []
+    assert len(specialized) == 1
+    func = specialized[0].function
+    assert func.name == "mul"
+    assert [p.type_.unqualified_non_ref_type_name for p in func.params] == [
+        "float",
+        "float",
+    ]
+    assert func.return_type.unqualified_non_ref_type_name == "float"
+
+
+def test_unmappable_numba_arg_skips_overload(deduction_decls):
+    """Skip overloads when Numba args cannot map to C++ types."""
+    overloads = _get_function_templates(deduction_decls, "add")
+    unmappable = nbtypes.float32[:]
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="add",
+        overloads=overloads,
+        args=(unmappable, unmappable),
+        overrides=None,
+    )
+
+    assert intent_errors == []
+    assert specialized == []
+
+
+def test_array_param_decays_to_pointer_with_pass_ptr(deduction_decls):
+    """Array types like T[N] decay to T* when pass_ptr is applied."""
+    overloads = _get_function_templates(deduction_decls, "fill_array")
+    specialized, intent_errors = deduce_templated_overloads(
+        qualname="fill_array",
+        overloads=overloads,
+        args=(nbtypes.CPointer(nbtypes.int32), nbtypes.int32),
+        overrides={"arr": "out_ptr"},
+    )
+
+    assert intent_errors == []
+    assert len(specialized) == 1
+    func = specialized[0].function
+    # After specialization, params should be int[4] (reference) and int
+    assert func.params[0].type_.unqualified_non_ref_type_name == "int[4]"
+    assert func.params[0].type_.is_left_reference()
+    assert func.params[1].type_.unqualified_non_ref_type_name == "int"