-
Notifications
You must be signed in to change notification settings - Fork 18
Add intent-based ref handling and out-return support for bindings #278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
- Introduced `ArgIntent` enum to manage parameter exposure in Numba. - Implemented `IntentPlan` for computing per-parameter intent plans. - Updated function and struct binding methods to support argument intents. - Added utility functions for converting AST types to Numba types with intent considerations. - Enhanced static binding generation to include argument intent overrides. - Added tests for out-parameter functions to validate new functionality. This commit lays the groundwork for more flexible and explicit handling of function argument intents in Numba, improving the integration with C++ parameter types.
📝 WalkthroughWalkthroughThis PR implements an argument-intent system (ArgIntent/IntentPlan) and threads intent-aware handling through types, call conventions, renderers, generator config, and tests to support in / inout_ptr / out_ptr / out_return semantics across C++→Numba CUDA bindings. Also adds a codespell ignore for "inout". Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant Binder as Binding entrypoint
participant Intent as compute_intent_plan
participant Renderer as Static/Dynamic Renderer
participant ArgConv as prepare_ir_types
participant CallConv as FunctionCallConv
participant Shim as Lowering/Shim
User->>Binder: bind_cxx_function(..., arg_intent={...})
Binder->>Intent: compute_intent_plan(params, param_types, overrides)
Intent-->>Binder: IntentPlan(intents, visible_indices, out_return_indices, pass_ptr_mask)
Binder->>Renderer: render signature with IntentPlan
Renderer->>ArgConv: prepare_ir_types(context, argtys, pass_ptr_mask)
ArgConv-->>Renderer: IR arg types (selective pointer wrapping)
Renderer->>CallConv: instantiate with intent_plan, out_return_types, cxx_return_type
CallConv->>Shim: generate arg setup and return gathering (intent-aware)
Shim-->>User: produced binding / kernel-ready shim
sequenceDiagram
participant Kernel
participant Setup
participant CxxFunc
participant Collector
participant KernelRet as Kernel
Kernel->>Setup: call shim with args + intent_plan
Setup->>Setup: allocate/use pointers per pass_ptr_mask
Setup->>CxxFunc: invoke C++ function with selective pointers/values
CxxFunc->>CxxFunc: mutate reference params / compute return
CxxFunc-->>Collector: modified refs + return
Collector->>KernelRet: assemble out_returns (tuple or single)
KernelRet-->>Kernel: deliver final result
Estimated code review effort🎯 4 (Complex) | ⏱️ ~50 minutes Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
numbast/src/numbast/args.py (1)
9-33: Update docstring to reflect pass_ptr_mask behavior.The docstring still implies unconditional pointer wrapping and doesn’t mention the new parameter, which makes the public API misleading.
📝 Suggested docstring update
@@ - Parameters + Parameters @@ - argtys : list[ir.Type] - List of LLVM IR types representing function arguments. + argtys : list[ir.Type] + List of LLVM IR types representing function arguments. + pass_ptr_mask : list[bool] | None, optional + When True for an argument and its value type is already a pointer, + pass it through without additional wrapping. Defaults to all False. @@ - list[ir.Type] - List of pointer types wrapping the value types of each argument. + list[ir.Type] + List of pointer types for the call ABI; pointer-typed arguments may be + passed through unchanged when pass_ptr_mask is set.numbast/src/numbast/static/tests/conftest.py (1)
26-46: Preserve caller-provided empty intent dict.Using
function_argument_intents or {}replaces a caller-supplied empty dict, which can break reference expectations if the caller relies on identity or later mutation.🔧 Suggested fix
- cfg.function_argument_intents = function_argument_intents or {} + if function_argument_intents is None: + function_argument_intents = {} + cfg.function_argument_intents = function_argument_intentsnumbast/src/numbast/static/struct.py (1)
598-805: Generated bindings referenceArgIntent/IntentPlanwithout importing them.When overrides are provided, the rendered code emits
ArgIntent.*andIntentPlan(...), but no import is added toself.Imports. This will raiseNameErrorin the generated bindings. Add the import when overrides are active.🛠️ Add required imports to generated bindings
- if overrides is None: + if overrides is None: # Cache Numba param and return types (as strings) self._nb_param_types = [ to_numba_arg_type_str(t) for t in self._method_decl.param_types ] @@ else: + self.Imports.add("from numbast.intent import ArgIntent, IntentPlan") method_plan = compute_intent_plan( params=self._method_decl.params, param_types=self._method_decl.param_types, overrides=overrides, allow_out_return=True, )numbast/src/numbast/static/function.py (2)
143-255: Rendered intent plans needArgIntent/IntentPlanimports.When overrides are set, the generated code embeds
ArgIntent.*andIntentPlan(...), but no import is added toself.Imports, leading toNameErrorat runtime. Add the import when overrides are active.🛠️ Add required imports to generated bindings
- else: + else: + self.Imports.add("from numbast.intent import ArgIntent, IntentPlan") plan = compute_intent_plan( params=self._decl.params, param_types=self._decl.param_types, overrides=overrides, allow_out_return=True, )
273-320: Guard against empty visible-arg lists to avoid invalid syntax.When all parameters are marked
out_return,self._argument_numba_typesbecomes an empty list (line 194), causingself._argument_numba_types_strto be an empty string. The templates then emit@lower(fn, )andsignature(ret, ), which are invalid Python due to trailing commas. Build the parameter list conditionally to add the comma only when parameters exist.🛠️ Suggested fix
- lowering_template = """ -@lower({func_name}, {params}) + lowering_template = """ +@lower({func_name}{params}) def impl(context, builder, sig, args): @@ `@property` def _signature_cases(self): """The python string that declares the signature of this function.""" return_type_name = str(self._return_numba_type_str) - param_types_str = ", ".join(str(t) for t in self._argument_numba_types) - return self.signature_template.format( - return_type=return_type_name, param_types=param_types_str - ) + if not self._argument_numba_types: + return f"signature({return_type_name})" + param_types_str = ", ".join(str(t) for t in self._argument_numba_types) + return self.signature_template.format( + return_type=return_type_name, param_types=param_types_str + ) @@ - self._lowering_rendered = self.lowering_template.format( - func_name=self.func_name_python, - params=self._argument_numba_types_str, + params = ( + f", {self._argument_numba_types_str}" + if self._argument_numba_types_str + else "" + ) + self._lowering_rendered = self.lowering_template.format( + func_name=self.func_name_python, + params=params, mangled_name=self._decl.mangled_name, return_type=self._return_numba_type_str, use_cooperative=use_cooperative, arg_is_ref=self._arg_is_ref, intent_plan=self._intent_plan_rendered, out_return_types=self._out_return_types_rendered, cxx_return_type=self._cxx_return_type_rendered, )
🤖 Fix all issues with AI agents
In `@numbast/src/numbast/function.py`:
- Around line 16-17: bind_cxx_operator_overload_function currently accepts an
arg_intent parameter but never applies or validates it, silently ignoring user
overrides; either apply intent planning here or reject overrides explicitly. Fix
by invoking compute_intent_plan with the derived argument types (use
to_numba_arg_type/to_numba_type results used in
bind_cxx_operator_overload_function) and pass the resulting intent plan into the
Function/ABI construction path (same way non-operator/static paths do), or if
you prefer a simple guard, raise a clear error when arg_intent is provided
(e.g., in bind_cxx_operator_overload_function) to refuse unsupported intent
overrides for operator overloads. Ensure references:
bind_cxx_operator_overload_function, compute_intent_plan,
to_numba_arg_type/to_numba_type, and the Function/ABI creation logic are updated
accordingly.
In `@numbast/src/numbast/intent_defs.py`:
- Around line 30-39: Add a __post_init__ method to the IntentPlan dataclass that
validates length invariants: ensure len(self.pass_ptr_mask) ==
len(self.visible_param_indices) and optionally that len(self.intents) >=
max(self.visible_param_indices, default=-1)+1 and that all indices in
self.visible_param_indices and self.out_return_indices are within range of
intents; raise a ValueError with a clear message if any check fails. Implement
this in the IntentPlan class so the invariant between pass_ptr_mask and
visible_param_indices is enforced at construction.
In `@numbast/src/numbast/intent.py`:
- Around line 11-27: The current parse logic is attached via
setattr(_parse_arg_intent) which is unconventional and the "return" alias maps
to out_return and may be confusing; refactor by moving the parsing code into
ArgIntent as a proper `@classmethod` named parse (replace _parse_arg_intent), and
either remove the ambiguous "return" alias from the v2 == checks or document it
in the ArgIntent docstring/enum comment so callers know it maps to out_return;
ensure the method keeps the same alias checks for "in",
"inout_ptr"/"inout"/"mutate"/"mutative", "out_ptr"/"out", and
"out_return"/"outret" but no longer relies on setattr.
- Around line 69-103: Extract the duplicate dict-extraction into a small helper
(e.g. _extract_intent_value(raw)) used in both places to replace the chained
.get() calls; helper should: if not isinstance(raw, dict) return raw as-is,
otherwise check the keys in order ["intent","Intent","INTENT"] and if a key
exists in the dict return its value (even if that value is None) so we don't
silently fall through, and return None only if none of the keys exist; then call
_parse_arg_intent(ArgIntent, _extract_intent_value(raw)) in both the
index-handling loop and the name-handling loop (referencing variables/methods:
overrides, raw, _parse_arg_intent, ArgIntent, normalized, params, name_to_idx).
In `@numbast/src/numbast/static/callconv.py`:
- Around line 19-21: The current import-patching logic that builds
_CALLCONV_SRC_PATCHED by string-replacing imports from _CALLCONV_SRC is fragile;
instead modify the source template _CALLCONV_SRC to include explicit commented
markers (e.g., "# BEGIN GENERATED IMPORTS" and "# END GENERATED IMPORTS") around
the import block and then replace the whole marked section when generating
_CALLCONV_SRC_PATCHED; update the code that constructs _CALLCONV_SRC_PATCHED to
locate these markers (search for the start and end marker strings) and remove or
replace the enclosed text rather than doing simple .replace("from numbast.args
import prepare_ir_types","")/.replace("from numbast.intent import
IntentPlan",""), keeping the marker names consistent with the variables
_CALLCONV_SRC and _CALLCONV_SRC_PATCHED so the change is localized and robust.
In `@numbast/src/numbast/tools/static_binding_generator.py`:
- Around line 104-109: The factory method from_params is missing the optional
function_argument_intents parameter even though the class exposes that attribute
and the YAML config supports it; update from_params to accept an optional
function_argument_intents: dict[str, dict[str|int, str|dict]] parameter and
propagate it into the created instance (either by passing it into the
constructor call used inside from_params or by setting
instance.function_argument_intents before returning), ensuring callers and tests
no longer need to set the attribute post-construction; reference the from_params
function and the function_argument_intents attribute when making the change.
In `@numbast/tests/test_function.py`:
- Around line 119-132: The test duplicates parsing and binding logic; extract
the inline setup into a reusable pytest fixture (e.g.,
_sample_out_ptr_functions) that performs parse_declarations_from_source(...),
creates MemoryShimWriter(f'#include "{p}"'), calls bind_cxx_functions(...) with
arg_intent mapping for "add_out"->"out_ptr" and "add_in_ref"->"in", and returns
(func_bindings, shim_writer); then update the test to use this fixture instead
of repeating the same code and remove the duplicated block now covered by the
fixture.
- Around line 49-56: The two lookups using next(...) for add_one_inplace and
set_42 on func_bindings can raise StopIteration with an unclear traceback;
introduce a small helper (e.g., find_binding or assert_find_binding) that
iterates func_bindings, returns the function when
getattr(f,"__name__",None)==name, and raises a clear AssertionError listing
available binding names if not found, then use that helper to assign
add_one_inplace and set_42 (or alternatively use next(..., None) followed by an
explicit assert with a descriptive message).
| @dataclass(frozen=True) | ||
| class IntentPlan: | ||
| """ | ||
| Normalized intent plan for a callable with N original parameters. | ||
| """ | ||
|
|
||
| intents: tuple[ArgIntent, ...] # length N | ||
| visible_param_indices: tuple[int, ...] # subset of [0..N) | ||
| out_return_indices: tuple[int, ...] # subset of [0..N) | ||
| pass_ptr_mask: tuple[bool, ...] # aligned with visible params only |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧹 Nitpick | 🔵 Trivial
Consider adding __post_init__ validation for length invariants.
The pass_ptr_mask is documented as "aligned with visible params only," but the dataclass doesn't enforce that len(pass_ptr_mask) == len(visible_param_indices). Since compute_intent_plan constructs these correctly, this is low-risk, but adding a __post_init__ check would prevent misuse and make the invariant explicit.
♻️ Optional validation
`@dataclass`(frozen=True)
class IntentPlan:
"""
Normalized intent plan for a callable with N original parameters.
"""
intents: tuple[ArgIntent, ...] # length N
visible_param_indices: tuple[int, ...] # subset of [0..N)
out_return_indices: tuple[int, ...] # subset of [0..N)
pass_ptr_mask: tuple[bool, ...] # aligned with visible params only
+
+ def __post_init__(self):
+ if len(self.pass_ptr_mask) != len(self.visible_param_indices):
+ raise ValueError(
+ f"pass_ptr_mask length ({len(self.pass_ptr_mask)}) must match "
+ f"visible_param_indices length ({len(self.visible_param_indices)})"
+ )🤖 Prompt for AI Agents
In `@numbast/src/numbast/intent_defs.py` around lines 30 - 39, Add a __post_init__
method to the IntentPlan dataclass that validates length invariants: ensure
len(self.pass_ptr_mask) == len(self.visible_param_indices) and optionally that
len(self.intents) >= max(self.visible_param_indices, default=-1)+1 and that all
indices in self.visible_param_indices and self.out_return_indices are within
range of intents; raise a ValueError with a clear message if any check fails.
Implement this in the IntentPlan class so the invariant between pass_ptr_mask
and visible_param_indices is enforced at construction.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's good but I'm intentionally making IntentPlan a minimalistic class so that it can be vended as-is to the generated bindings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@isVoid, that makes perfect sense! Keeping IntentPlan minimal is the right call if it needs to be included in generated bindings—adding validation logic would unnecessarily complicate the generated code. Since compute_intent_plan already enforces the invariant during construction, you're covered. 👍
|
Note Docstrings generation - SUCCESS |
Docstrings generation was requested by @isVoid. * #278 (comment) The following files were modified: * `numbast/src/numbast/args.py` * `numbast/src/numbast/callconv.py` * `numbast/src/numbast/class_template.py` * `numbast/src/numbast/function.py` * `numbast/src/numbast/intent.py` * `numbast/src/numbast/static/callconv.py` * `numbast/src/numbast/static/function.py` * `numbast/src/numbast/static/struct.py` * `numbast/src/numbast/static/tests/conftest.py` * `numbast/src/numbast/static/tests/data/src/function_out.cu` * `numbast/src/numbast/static/tests/test_function_static_bindings.py` * `numbast/src/numbast/static/types.py` * `numbast/src/numbast/struct.py` * `numbast/src/numbast/tools/static_binding_generator.py` * `numbast/src/numbast/types.py` * `numbast/tests/test_function.py`
This adds a template argument deduction module and unit tests, and it incorporates the recently added argument intent concept from Numbast PR #278 so visible arity and pointer passing stay consistent across templated overloads. At a high level, the module accepts a list of `ast_canopy`-parsed function templates plus optional intent overrides, and returns `FunctionTemplate` instances with deduced argument types for the original parameter names. Tested cases: - Overload selection by visible arity. - C++: `template <typename T> __device__ T add(T a, T b);` and `template <typename T> __device__ T add(T a, T b, T c);` - Example: `add(int, int)` picks the two-arg overload and deduces `T=int`. - Conflicting placeholder deduction skips an overload. - C++: `template <typename T> __device__ T add(T a, T b);` - Example: `add(int, float)` yields no specialization because `T` conflicts. - Non-templated parameter types must match. - C++: `template <typename T> __device__ T add_int(int a, T b);` - Example: `add_int(int, float)` specializes, `add_int(float, float)` does not. - Return-only placeholders are skipped. - C++: `template <typename T> __device__ T return_only();` - Example: `return_only<T>()` cannot deduce `T` with no args, so it is skipped. - **Intent overrides including out_return.** - C++: `template <typename T> __device__ void store_ref(T &out, T value);` - Examples: `store_ref(CPointer(int), int)` with `overrides={"out": "out_ptr"}` or `{"out": "inout_ptr"}` deduces `T=int`; `store_ref(int)` with `overrides={"out": "out_return"}` also deduces `T=int` via hidden out param. - Invalid overrides surface as errors. - C++: `template <typename T> __device__ void bad_out(T value);` - Example: `bad_out(value=out_ptr)` reports a `ValueError` in `intent_errors`. - Struct method specialization. - C++: `struct Box { template <typename T> __device__ T mul(T a, T b) const; };` - Example: `Box::mul(float, float)` deduces `T=float`. - Unmappable Numba args are skipped. - C++: `template <typename T> __device__ T add(T a, T b);` - Example: `add(float32[:], float32[:])` yields no specialization. This currently is a standalone module that's not wired into any existing binding generation so that it can be tested and reasoned independently. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Template overload deduction with automatic specialization for templated functions and methods. * **Improvements** * Better handling of pointer/reference parameters, placeholder resolution, conflict detection, and optional detailed debug tracing. * **Bug Fixes** * Corrected pointer-type stringification so pointer types are recognized properly. * **Tests** * Comprehensive tests for overload selection, conflicts, overrides, pointer/ref cases, and method specialization. * **Chores** * CI: ensure compiler availability before running style checks. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
In ISO C++, whether an argument is mutated or not is not directly inferable from the signature itself. For example,
both accepts a reference as input, but whether the function modifies the argument is unknown until further analysis into the body is performed. This creates confusions for language bindings on whether the arguments should be passed by reference or value, as other language may dictate the argument passing semantics differently from C++.
Certain compiler provides additional annotation features to denote them. Such as SAL.
In this PR, Numbast introduces an
argument intentoption that allows user to configure argument passing mode on per-function, per-parameter basis. Per argument, the following options are available:Take the following C++ signature as an example:
A typical argument intent setup for Numbast looks like:
This indicates that argument
outis returned as the functions return value in the corresponding binding. And since this function already has a return value, the binding will now return a tuple of ints, with first corresponds to the exit code, and the second corresponds to the result. The Python binding signature:Alternatively, if intent is set to:
"out": "out_ptr", the signature becomes:Summary by CodeRabbit
New Features
Tests
✏️ Tip: You can customize this high-level summary in your review settings.