Conversation

@isVoid
Collaborator

@isVoid isVoid commented Jan 23, 2026

This adds a template argument deduction module and unit tests, and it incorporates the recently added argument intent concept from Numbast PR #278 so visible arity and pointer passing stay consistent across templated overloads.

At a high level, the module accepts a list of ast_canopy-parsed function templates plus optional intent overrides, and returns FunctionTemplate instances with deduced argument types for the original parameter names.

Tested cases:

  • Overload selection by visible arity.
    • C++: template <typename T> __device__ T add(T a, T b); and template <typename T> __device__ T add(T a, T b, T c);
    • Example: add(int, int) picks the two-arg overload and deduces T=int.
  • Conflicting placeholder deduction skips an overload.
    • C++: template <typename T> __device__ T add(T a, T b);
    • Example: add(int, float) yields no specialization because T conflicts.
  • Non-templated parameter types must match.
    • C++: template <typename T> __device__ T add_int(int a, T b);
    • Example: add_int(int, float) specializes, add_int(float, float) does not.
  • Return-only placeholders are skipped.
    • C++: template <typename T> __device__ T return_only();
    • Example: return_only<T>() cannot deduce T with no args, so it is skipped.
  • Intent overrides including out_return.
    • C++: template <typename T> __device__ void store_ref(T &out, T value);
    • Examples: store_ref(CPointer(int), int) with overrides={"out": "out_ptr"} or {"out": "inout_ptr"} deduces T=int; store_ref(int) with overrides={"out": "out_return"} also deduces T=int via hidden out param.
  • Invalid overrides surface as errors.
    • C++: template <typename T> __device__ void bad_out(T value);
    • Example: bad_out(value=out_ptr) reports a ValueError in intent_errors.
  • Struct method specialization.
    • C++: struct Box { template <typename T> __device__ T mul(T a, T b) const; };
    • Example: Box::mul(float, float) deduces T=float.
  • Unmappable Numba args are skipped.
    • C++: template <typename T> __device__ T add(T a, T b);
    • Example: add(float32[:], float32[:]) yields no specialization.

The module is currently standalone and not wired into any existing binding generation, so it can be tested and reasoned about independently.
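
The tested cases above can be illustrated with a minimal, self-contained sketch. It is not the module's implementation: the `Overload` tuple and the `deduce` helper are hypothetical stand-ins that work on plain type strings instead of ast_canopy declarations and Numba types, and intents are left out entirely.

```python
from typing import Optional

# Hypothetical stand-in for a parsed function template:
# (name, template placeholder names, parameter type strings).
Overload = tuple[str, tuple[str, ...], tuple[str, ...]]


def deduce(overload: Overload, arg_types: list[str]) -> Optional[dict[str, str]]:
    """Return a placeholder -> type mapping, or None if the overload is skipped."""
    _, placeholders, params = overload
    if len(params) != len(arg_types):
        return None  # arity mismatch
    deduced: dict[str, str] = {}
    for param, arg in zip(params, arg_types):
        if param in placeholders:
            # setdefault returns the earlier deduction when one exists.
            if deduced.setdefault(param, arg) != arg:
                return None  # conflicting deduction for the same placeholder
        elif param != arg:
            return None  # non-templated parameter must match exactly
    if any(p not in deduced for p in placeholders):
        return None  # e.g. a return-only placeholder cannot be deduced
    return deduced


add2: Overload = ("add", ("T",), ("T", "T"))
add3: Overload = ("add", ("T",), ("T", "T", "T"))
add_int: Overload = ("add_int", ("T",), ("int", "T"))
return_only: Overload = ("return_only", ("T",), ())

# Only the two-argument overload survives for add(int, int).
candidates = [deduce(o, ["int", "int"]) for o in (add2, add3)]
```

In this toy model `candidates` holds one mapping for `add2` and `None` for `add3`, mirroring the "overload selection by visible arity" case.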

Summary by CodeRabbit

  • New Features

    • Template overload deduction with automatic specialization for templated functions and struct methods.
  • Improvements

    • Enhanced handling of pointer/reference parameters, placeholder resolution, conflict detection, and optional detailed debug tracing.
  • Bug Fixes

    • Corrected pointer-type handling so pointer types are recognized and represented properly.
  • Tests

    • Comprehensive tests for overload selection, conflicts, overrides, pointer/ref cases, and method specialization.
  • Chores

    • CI: ensure compiler availability before running style checks.

@coderabbitai
Contributor

coderabbitai bot commented Jan 23, 2026

📝 Walkthrough

Adds a template-overload deduction engine that maps Numba types to C++-style strings, deduces template parameters from argument types, specializes templated functions and struct methods, adds tests for various deduction scenarios, tweaks pointer type handling, and ensures g++ is installed in CI.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Core deduction module: `numbast/src/numbast/deduction.py` | New module implementing template-overload deduction: debug utilities, C++-style type normalization, Numba→C++ mapping, pattern-based template-parameter extraction, placeholder replacement, function/struct-method specialization, unresolved-placeholder detection, and public entry `deduce_templated_overloads`. |
| Deduction tests: `numbast/tests/test_deduction.py` | New pytest suite exercising `deduce_templated_overloads`: arity selection, conflicting deductions, non-templated param matching, return-placeholder skipping, pointer/override handling and related errors, struct-method specialization, and unmappable-argument cases; includes fixtures parsing CUDA header declarations. |
| Type conversion tweaks: `numbast/src/numbast/types.py` | `to_numba_type` now falls back to `nbtypes.undefined` for unknown entries; `to_c_type_str` gains pointer handling by recognizing `nbtypes.CPointer` and appending `*` to base type strings. |
| CI style check: `ci/check_style.sh` | Ensures g++ is available by adding `apt-get update` and `apt-get install -y g++` before pre-commit installation. |

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Deductor as "deduce_templated_overloads"
    participant IntentPlanner as "Intent Planner"
    participant TypeDeducer as "Type Pattern Deducer"
    participant Specializer as "Template Specializer"
    participant Validator as "Placeholder Validator"

    Caller->>Deductor: call(overloads, args, overrides, debug)
    activate Deductor

    alt overrides provided
        Deductor->>IntentPlanner: compute visible params + intents
        IntentPlanner-->>Deductor: visible params, intent errors
    else no overrides
        Deductor->>Deductor: use all function parameters as visible
    end

    loop per overload
        Deductor->>Deductor: check visible-parameter arity vs args
        alt arity matches
            loop per visible parameter
                Deductor->>TypeDeducer: match param pattern with arg type
                TypeDeducer-->>Deductor: mapping fragment or conflict
            end
            Deductor->>Specializer: apply combined mappings to template
            Specializer-->>Deductor: specialized function/struct-method
            Deductor->>Validator: check unresolved placeholders
            Validator-->>Deductor: accept or discard specialization
        else arity mismatch
            Deductor-->>Deductor: skip overload
        end
    end

    Deductor-->>Caller: return specialized templates + collected errors
    deactivate Deductor

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🐰

I hopped through types and patterns bright,
I nudged placeholders into the light,
Args and templates kissed in tune,
Functions grew concrete by moon,
A grateful twitch — specialization's right!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The pull request title accurately and concisely describes the main change: adding a new template argument deduction module and accompanying unit tests, which aligns perfectly with the primary objectives and file additions. |

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 195-262: The except block in deduce_templated_overloads currently
catches Exception when calling compute_intent_plan; narrow this to only the
documented exceptions by changing the handler to except (ValueError, TypeError)
as exc so only intent-related ValueError/TypeError cases are appended to
intent_errors; locate the try/except around compute_intent_plan (call site:
compute_intent_plan, variables plan/intent_errors) and replace the broad
Exception catch with the tuple of specific exception types.
- Around line 64-69: The helper _numba_arg_to_cxx_type currently catches all
Exceptions when calling to_c_type_str(arg), which can hide programming errors;
change the broad except to catch only ValueError (the error to_c_type_str raises
for unknown Numba types) so that other exceptions (e.g.,
AttributeError/TypeError) still surface; keep the behavior of normalizing the
arg via _normalize_numba_arg_type and returning None on ValueError after
attempting to normalize the C++ type with _normalize_cxx_type_str.
- Around line 72-77: _param_type_matches_arg calls to_numba_type(cxx_type) which
can raise KeyError for unknown C++ types; wrap the to_numba_type call in a
try/except and handle the KeyError by treating unknown mappings the same way as
nbtypes.undefined (i.e., return True or otherwise mark as compatible). Update
the function _param_type_matches_arg to catch KeyError from to_numba_type, log
or comment if desired, and then fallback to the existing branch that returns
True for undefined types; keep using _normalize_numba_arg_type(arg) for the
final comparison.
- Around line 80-105: _deduce_from_type_pattern currently records each
placeholder only once in `order`, while the pattern replacement creates a capture
group for every occurrence; update the logic so that `order` records the
placeholder once per occurrence (i.e., append `ph` each time you replace it), so
that match.groups() can be zipped against every placeholder occurrence and
conflicting values are checked against previously deduced values. Concretely,
iterate through the placeholder occurrences in cxx_type (or perform the
replacements with a callback) and, for each placeholder found, add a
corresponding r"(.*?)" to the regex and append that placeholder to `order`; then
keep the existing match, strip, emptiness check, and conflict detection against
`deduced` unchanged.
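
The per-occurrence capture idea in the last item can be sketched as follows. This is an illustrative reconstruction, not the code in `deduction.py`: the function name `deduce_from_pattern` and its signature are assumptions, it uses `(.+?)` rather than `(.*?)` so empty captures never match, and it deliberately ignores identifier boundaries to keep the example short.

```python
import re
from typing import Optional


def deduce_from_pattern(
    cxx_type: str, arg_type: str, placeholders: list[str]
) -> Optional[dict[str, str]]:
    """Match arg_type against cxx_type, capturing every placeholder occurrence."""
    if not placeholders:
        return {} if cxx_type == arg_type else None
    # Longest names first so a placeholder "TT" is not read as two "T"s.
    names = sorted(placeholders, key=len, reverse=True)
    splitter = re.compile("(" + "|".join(map(re.escape, names)) + ")")

    order: list[str] = []  # one entry per occurrence, in capture-group order
    parts: list[str] = []
    # re.split with one capturing group alternates literal text (even
    # indices) with the matched placeholder names (odd indices).
    for i, piece in enumerate(splitter.split(cxx_type)):
        if i % 2:
            order.append(piece)
            parts.append(r"(.+?)")
        else:
            parts.append(re.escape(piece))

    match = re.fullmatch("".join(parts), arg_type)
    if match is None:
        return None
    deduced: dict[str, str] = {}
    for name, value in zip(order, match.groups()):
        value = value.strip()
        # Reject empty captures and values that disagree with an earlier one.
        if not value or deduced.setdefault(name, value) != value:
            return None
    return deduced
```

Because `order` has one entry per capture group, a pattern such as `pair<T, T>` compares both captured values for `T` instead of silently keeping only the first.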

Comment on lines +72 to +77
def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
    """Best-effort compatibility check for non-templated parameters."""
    nb_expected = to_numba_type(cxx_type)
    if nb_expected is nbtypes.undefined:
        return True
    return nb_expected == _normalize_numba_arg_type(arg)
Contributor

@coderabbitai coderabbitai bot Jan 23, 2026

⚠️ Potential issue | 🟡 Minor

Missing exception handling for unknown C++ types.

to_numba_type can raise KeyError when cxx_type is not in CTYPE_MAPS. This could cause the deduction process to fail unexpectedly for unsupported types. Consider handling the exception gracefully, similar to _numba_arg_to_cxx_type.

🐛 Proposed fix
 def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
     """Best-effort compatibility check for non-templated parameters."""
-    nb_expected = to_numba_type(cxx_type)
+    try:
+        nb_expected = to_numba_type(cxx_type)
+    except KeyError:
+        # Unknown C++ type; assume compatible as a best-effort fallback.
+        return True
     if nb_expected is nbtypes.undefined:
         return True
     return nb_expected == _normalize_numba_arg_type(arg)

Collaborator Author

In a local branch, I updated to_numba_type to return numba.undefined for types that are not in Numbast's type system. That is how this case will be handled here; I will cherry-pick that change into this PR.

Contributor

@isVoid, thanks for the clarification! Returning numba.undefined from to_numba_type for unknown C++ types is a cleaner solution, and it will work perfectly with the existing logic in _param_type_matches_arg (lines 75-76 already handle undefined by treating it as compatible).

Looking forward to seeing the cherry-picked change in this PR!


✏️ Learnings added
Learnt from: isVoid
Repo: NVIDIA/numbast PR: 283
File: numbast/src/numbast/deduction.py:78-83
Timestamp: 2026-01-26T22:21:53.162Z
Learning: In `numbast/src/numbast/types.py`, the `to_numba_type` function returns `numba.undefined` for C++ types not in Numbast's type system, rather than raising a KeyError.


🧠 Learnings used
Learnt from: isVoid
Repo: NVIDIA/numbast PR: 278
File: numbast/src/numbast/function.py:16-17
Timestamp: 2026-01-17T18:17:18.274Z
Learning: In `numbast/src/numbast/function.py`, operator overloads have different argument intent expectations and require special handling separate from regular functions when implementing arg_intent support.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI script currently calls "sudo apt-get update" and "sudo
apt-get install -y g++" which fails when sudo isn't present; update the install
block to check for root (use [ "$(id -u)" -eq 0 ]), and if running as root run
"apt-get update" and "apt-get install -y g++", otherwise fall back to "sudo
apt-get update" and "sudo apt-get install -y g++" so the commands succeed both
in root CI containers and non-root environments.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI install step can hang on interactive prompts; update
the apt-get invocation in ci/check_style.sh by setting the
DEBIAN_FRONTEND=noninteractive environment variable when running apt-get install
so it runs non-interactively (e.g., prefix the apt-get install -y g++ command
with DEBIAN_FRONTEND=noninteractive) while keeping apt-get update and the -y
flag intact; this ensures apt-get update and apt-get install -y g++ complete
reliably in the CI container.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 184-198: The substring checks in _unresolved_placeholders (which
inspects func.return_type.unqualified_non_ref_type_name and each
param.type_.unqualified_non_ref_type_name against placeholder_names) can produce
false positives (e.g., "T" matching inside "MyThing"); update the detection to
use word-boundary/token checks instead of naive "p in type_name" — for example,
compile placeholder_names into regexes that use \b boundaries or split the type
name into identifier tokens and check exact equality against each placeholder
token so only whole-placeholder matches mark unresolved.
- Around line 114-117: The _replace_placeholders function can corrupt
overlapping keys (e.g., "T" and "TT"); update it to perform replacements in a
safe order by iterating replacements keys sorted by length descending (or build
a single regex from escaped keys and replace matches using the replacements
dict) so longer placeholders are replaced before shorter ones; modify the loop
in _replace_placeholders to use sorted(replacements.keys(), key=len,
reverse=True) or a regex-based approach and then apply replacements accordingly.
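
Both items above come down to treating placeholders as whole identifiers. A minimal sketch of that idea follows; it swaps in identifier tokenization rather than the length-sorted loop suggested in the comment, and the function names and signatures are assumptions that only loosely mirror the private helpers in `deduction.py`.

```python
import re

# A C++-style identifier: letter or underscore, then letters, digits, underscores.
_IDENT = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def replace_placeholders(type_name: str, replacements: dict[str, str]) -> str:
    """Substitute whole-identifier placeholders in a single pass."""
    # Each identifier token is looked up exactly once, so "TT" is never
    # rewritten through "T" and replacement text is never scanned again.
    return _IDENT.sub(
        lambda m: replacements.get(m.group(0), m.group(0)), type_name
    )


def unresolved_placeholders(type_name: str, placeholder_names: set[str]) -> set[str]:
    """Return the placeholders that still appear as whole identifiers."""
    return placeholder_names & set(_IDENT.findall(type_name))
```

With this approach `MyThing<int>` is not reported as containing an unresolved `T`, while `const T *` is.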

@isVoid isVoid enabled auto-merge (squash) January 26, 2026 23:05
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
numbast/src/numbast/types.py (1)

106-129: Propagate unknown base types through pointer/array wrappers.

to_numba_type now returns nbtypes.undefined for unknown C++ types, but pointer/array branches still wrap unknown bases in CPointer/UniTuple. That makes _param_type_matches_arg treat them as concrete mismatches instead of “unknown/compatible”, so overloads with unknown pointer/array types may be skipped incorrectly.

🛠️ Proposed fix
     if ty.endswith("*"):
         base_ty = ty.rstrip("*").rstrip(" ")
-        return nbtypes.CPointer(to_numba_type(base_ty))
+        base_nb = to_numba_type(base_ty)
+        return (
+            nbtypes.undefined
+            if base_nb is nbtypes.undefined
+            else nbtypes.CPointer(base_nb)
+        )

     # A pointer to an array type, collapsed as a simple array pointer.
     if "(*)[" in ty:
         base_ty = ty.split(" (")[0]
-        return nbtypes.CPointer(to_numba_type(base_ty))
+        base_nb = to_numba_type(base_ty)
+        return (
+            nbtypes.undefined
+            if base_nb is nbtypes.undefined
+            else nbtypes.CPointer(base_nb)
+        )

     # Support for array type is still incomplete in ast_canopy,
     # doing manual parsing for array type here.
     arr_type_pat = r"(.*)\[(\d+)\]"
     is_array_type = re.match(arr_type_pat, ty)
     if is_array_type:
         base_ty, size = is_array_type.groups()
-        return nbtypes.UniTuple(to_numba_type(base_ty), int(size))
+        base_nb = to_numba_type(base_ty)
+        if base_nb is nbtypes.undefined:
+            return nbtypes.undefined
+        return nbtypes.UniTuple(base_nb, int(size))

Based on learnings, keep unknown-type fallbacks consistent across wrappers.

🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 86-153: The placeholder regex currently matches substrings inside
identifiers causing false positives; update the construction of
placeholder_regex in _deduce_from_type_pattern to enforce C++ identifier
boundaries (e.g., use lookarounds such as
(?<![A-Za-z0-9_])(?:PH_ALTERNATION)(?![A-Za-z0-9_])) so placeholders like "T"
only match as whole identifiers, keeping the rest of the logic
(placeholder_patterns, placeholder_regex.finditer, order, pattern_parts, and
subsequent matching) unchanged.
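
The lookaround form suggested above can be tried in isolation. The sketch below assumes a hypothetical helper `placeholder_regex` that builds the alternation from a list of placeholder names; only the lookbehind/lookahead shape is taken from the review comment.

```python
import re


def placeholder_regex(placeholders: list[str]) -> re.Pattern:
    """Compile a regex that matches placeholders only as whole identifiers."""
    if not placeholders:
        raise ValueError("at least one placeholder name is required")
    # Longest first, so "TT" wins over "T" when both are placeholders.
    names = sorted(placeholders, key=len, reverse=True)
    alternation = "|".join(re.escape(name) for name in names)
    # The lookarounds reject matches that touch another identifier character.
    return re.compile(rf"(?<![A-Za-z0-9_])(?:{alternation})(?![A-Za-z0-9_])")


rx = placeholder_regex(["T", "U"])
found = [m.group(0) for m in rx.finditer("Tensor<T, const U *>")]
```

Here the leading `T` of `Tensor` is skipped because it is followed by an identifier character, so `found` contains only the standalone `T` and `U`.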

@isVoid isVoid merged commit 9fcf0bb into NVIDIA:main Jan 27, 2026
26 checks passed