-
Notifications
You must be signed in to change notification settings - Fork 18
Add template argument deduction module and tests #283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
📝 WalkthroughWalkthroughAdds a template-overload deduction engine that maps Numba types to C++-style strings, deduces template parameters from argument types, specializes templated functions and struct methods, adds tests for various deduction scenarios, tweaks pointer type handling, and ensures g++ is installed in CI. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller
participant Deductor as "deduce_templated_overloads"
participant IntentPlanner as "Intent Planner"
participant TypeDeducer as "Type Pattern Deducer"
participant Specializer as "Template Specializer"
participant Validator as "Placeholder Validator"
Caller->>Deductor: call(overloads, args, overrides, debug)
activate Deductor
alt overrides provided
Deductor->>IntentPlanner: compute visible params + intents
IntentPlanner-->>Deductor: visible params, intent errors
else no overrides
Deductor->>Deductor: use all function parameters as visible
end
loop per overload
Deductor->>Deductor: check visible-parameter arity vs args
alt arity matches
loop per visible parameter
Deductor->>TypeDeducer: match param pattern with arg type
TypeDeducer-->>Deductor: mapping fragment or conflict
end
Deductor->>Specializer: apply combined mappings to template
Specializer-->>Deductor: specialized function/struct-method
Deductor->>Validator: check unresolved placeholders
Validator-->>Deductor: accept or discard specialization
else arity mismatch
Deductor-->>Deductor: skip overload
end
end
Deductor-->>Caller: return specialized templates + collected errors
deactivate Deductor
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 195-262: The except block in deduce_templated_overloads currently
catches Exception when calling compute_intent_plan; narrow this to only the
documented exceptions by changing the handler to except (ValueError, TypeError)
as exc so only intent-related ValueError/TypeError cases are appended to
intent_errors; locate the try/except around compute_intent_plan (call site:
compute_intent_plan, variables plan/intent_errors) and replace the broad
Exception catch with the tuple of specific exception types.
- Around line 64-69: The helper _numba_arg_to_cxx_type currently catches all
Exceptions when calling to_c_type_str(arg), which can hide programming errors;
change the broad except to catch only ValueError (the error to_c_type_str raises
for unknown Numba types) so that other exceptions (e.g.,
AttributeError/TypeError) still surface; keep the behavior of normalizing the
arg via _normalize_numba_arg_type and returning None on ValueError after
attempting to normalize the C++ type with _normalize_cxx_type_str.
- Around line 72-77: _param_type_matches_arg calls to_numba_type(cxx_type) which
can raise KeyError for unknown C++ types; wrap the to_numba_type call in a
try/except and handle the KeyError by treating unknown mappings the same way as
nbtypes.undefined (i.e., return True or otherwise mark as compatible). Update
the function _param_type_matches_arg to catch KeyError from to_numba_type, log
or comment if desired, and then fallback to the existing branch that returns
True for undefined types; keep using _normalize_numba_arg_type(arg) for the
final comparison.
- Around line 80-105: _deduce_from_type_pattern currently records each
placeholder only once in order while pattern replacement creates a capture group
for every occurrence; update the logic so order records the placeholder for each
occurrence (i.e., append ph each time you replace it) so that match.groups() can
be zipped to every placeholder occurrence and conflicting values are checked
against previously deduced values. Concretely, iterate through cxx_type
occurrences (or perform the replacements with a callback) and for each found
placeholder add a corresponding r"(.*?)" to the regex and append that
placeholder to order; then keep the existing match, strip, emptiness check and
conflict detection against deduced in the same way.
| def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool: | ||
| """Best-effort compatibility check for non-templated parameters.""" | ||
| nb_expected = to_numba_type(cxx_type) | ||
| if nb_expected is nbtypes.undefined: | ||
| return True | ||
| return nb_expected == _normalize_numba_arg_type(arg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing exception handling for unknown C++ types.
to_numba_type can raise KeyError when cxx_type is not in CTYPE_MAPS. This could cause the deduction process to fail unexpectedly for unsupported types. Consider handling the exception gracefully, similar to _numba_arg_to_cxx_type.
🐛 Proposed fix
def _param_type_matches_arg(cxx_type: str, arg: nbtypes.Type) -> bool:
"""Best-effort compatibility check for non-templated parameters."""
- nb_expected = to_numba_type(cxx_type)
+ try:
+ nb_expected = to_numba_type(cxx_type)
+ except KeyError:
+ # Unknown C++ type; assume compatible as a best-effort fallback.
+ return True
if nb_expected is nbtypes.undefined:
return True
return nb_expected == _normalize_numba_arg_type(arg)🤖 Prompt for AI Agents
In `@numbast/src/numbast/deduction.py` around lines 72 - 77,
_param_type_matches_arg calls to_numba_type(cxx_type) which can raise KeyError
for unknown C++ types; wrap the to_numba_type call in a try/except and handle
the KeyError by treating unknown mappings the same way as nbtypes.undefined
(i.e., return True or otherwise mark as compatible). Update the function
_param_type_matches_arg to catch KeyError from to_numba_type, log or comment if
desired, and then fallback to the existing branch that returns True for
undefined types; keep using _normalize_numba_arg_type(arg) for the final
comparison.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In a local branch, I updated to_numba_type to handle types not in Numbast's type system to return a numba.undefined type. That's how it will be handled here. I will cherry-pick that change to this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@isVoid, thanks for the clarification! Returning numba.undefined from to_numba_type for unknown C++ types is a cleaner solution, and it will work perfectly with the existing logic in _param_type_matches_arg (lines 75-76 already handle undefined by treating it as compatible).
Looking forward to seeing the cherry-picked change in this PR!
✏️ Learnings added
Learnt from: isVoid
Repo: NVIDIA/numbast PR: 283
File: numbast/src/numbast/deduction.py:78-83
Timestamp: 2026-01-26T22:21:53.162Z
Learning: In `numbast/src/numbast/types.py`, the `to_numba_type` function returns `numba.undefined` for C++ types not in Numbast's type system, rather than raising a KeyError.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
🧠 Learnings used
Learnt from: isVoid
Repo: NVIDIA/numbast PR: 278
File: numbast/src/numbast/function.py:16-17
Timestamp: 2026-01-17T18:17:18.274Z
Learning: In `numbast/src/numbast/function.py`, operator overloads have different argument intent expectations and require special handling separate from regular functions when implementing arg_intent support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI script currently calls "sudo apt-get update" and "sudo
apt-get install -y g++" which fails when sudo isn't present; update the install
block to check for root (use [ "$(id -u)" -eq 0 ]), and if running as root run
"apt-get update" and "apt-get install -y g++", otherwise fall back to "sudo
apt-get update" and "sudo apt-get install -y g++" so the commands succeed both
in root CI containers and non-root environments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@ci/check_style.sh`:
- Around line 7-9: The CI install step can hang on interactive prompts; update
the apt-get invocation in ci/check_style.sh by setting the
DEBIAN_FRONTEND=noninteractive environment variable when running apt-get install
so it runs non-interactively (e.g., prefix the apt-get install -y g++ command
with DEBIAN_FRONTEND=noninteractive) while keeping apt-get update and the -y
flag intact; this ensures apt-get update and apt-get install -y g++ complete
reliably in the CI container.
…lecting overloads
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 184-198: The substring checks in _unresolved_placeholders (which
inspects func.return_type.unqualified_non_ref_type_name and each
param.type_.unqualified_non_ref_type_name against placeholder_names) can produce
false positives (e.g., "T" matching inside "MyThing"); update the detection to
use word-boundary/token checks instead of naive "p in type_name" — for example,
compile placeholder_names into regexes that use \b boundaries or split the type
name into identifier tokens and check exact equality against each placeholder
token so only whole-placeholder matches mark unresolved.
- Around line 114-117: The _replace_placeholders function can corrupt
overlapping keys (e.g., "T" and "TT"); update it to perform replacements in a
safe order by iterating replacements keys sorted by length descending (or build
a single regex from escaped keys and replace matches using the replacements
dict) so longer placeholders are replaced before shorter ones; modify the loop
in _replace_placeholders to use sorted(replacements.keys(), key=len,
reverse=True) or a regex-based approach and then apply replacements accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
numbast/src/numbast/types.py (1)
106-129: Propagate unknown base types through pointer/array wrappers.
to_numba_typenow returnsnbtypes.undefinedfor unknown C++ types, but pointer/array branches still wrap unknown bases inCPointer/UniTuple. That makes_param_type_matches_argtreat them as concrete mismatches instead of “unknown/compatible”, so overloads with unknown pointer/array types may be skipped incorrectly.🛠️ Proposed fix
if ty.endswith("*"): base_ty = ty.rstrip("*").rstrip(" ") - return nbtypes.CPointer(to_numba_type(base_ty)) + base_nb = to_numba_type(base_ty) + return ( + nbtypes.undefined + if base_nb is nbtypes.undefined + else nbtypes.CPointer(base_nb) + ) # A pointer to an array type, collapsed as a simple array pointer. if "(*)[" in ty: base_ty = ty.split(" (")[0] - return nbtypes.CPointer(to_numba_type(base_ty)) + base_nb = to_numba_type(base_ty) + return ( + nbtypes.undefined + if base_nb is nbtypes.undefined + else nbtypes.CPointer(base_nb) + ) # Support for array type is still incomplete in ast_canopy, # doing manual parsing for array type here. arr_type_pat = r"(.*)\[(\d+)\]" is_array_type = re.match(arr_type_pat, ty) if is_array_type: base_ty, size = is_array_type.groups() - return nbtypes.UniTuple(to_numba_type(base_ty), int(size)) + base_nb = to_numba_type(base_ty) + if base_nb is nbtypes.undefined: + return nbtypes.undefined + return nbtypes.UniTuple(base_nb, int(size))Based on learnings, keep unknown-type fallbacks consistent across wrappers.
🤖 Fix all issues with AI agents
In `@numbast/src/numbast/deduction.py`:
- Around line 86-153: The placeholder regex currently matches substrings inside
identifiers causing false positives; update the construction of
placeholder_regex in _deduce_from_type_pattern to enforce C++ identifier
boundaries (e.g., use lookarounds such as
(?<![A-Za-z0-9_])(?:PH_ALTERNATION)(?![A-Za-z0-9_])) so placeholders like "T"
only match as whole identifiers, keeping the rest of the logic
(placeholder_patterns, placeholder_regex.finditer, order, pattern_parts, and
subsequent matching) unchanged.
This adds a template argument deduction module and unit tests, and it incorporates the recently added argument intent concept from Numbast PR #278 so visible arity and pointer passing stay consistent across templated overloads.
At a high level, the module accepts a list of
ast_canopy-parsed function templates plus optional intent overrides, and returnsFunctionTemplateinstances with deduced argument types for the original parameter names.Tested cases:
template <typename T> __device__ T add(T a, T b);andtemplate <typename T> __device__ T add(T a, T b, T c);add(int, int)picks the two-arg overload and deducesT=int.template <typename T> __device__ T add(T a, T b);add(int, float)yields no specialization becauseTconflicts.template <typename T> __device__ T add_int(int a, T b);add_int(int, float)specializes,add_int(float, float)does not.template <typename T> __device__ T return_only();return_only<T>()cannot deduceTwith no args, so it is skipped.template <typename T> __device__ void store_ref(T &out, T value);store_ref(CPointer(int), int)withoverrides={"out": "out_ptr"}or{"out": "inout_ptr"}deducesT=int;store_ref(int)withoverrides={"out": "out_return"}also deducesT=intvia hidden out param.template <typename T> __device__ void bad_out(T value);bad_out(value=out_ptr)reports aValueErrorinintent_errors.struct Box { template <typename T> __device__ T mul(T a, T b) const; };Box::mul(float, float)deducesT=float.template <typename T> __device__ T add(T a, T b);add(float32[:], float32[:])yields no specialization.This currently is a standalone module that's not wired into any existing binding generation so that it can be tested and reasoned independently.
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Tests
Chores
✏️ Tip: You can customize this high-level summary in your review settings.