From 4f2ad515569eb0887c75a569630735bcc6062ea0 Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Thu, 16 Apr 2026 22:33:35 +0200 Subject: [PATCH 1/8] perf: Port trampoline to rust --- .../pants/engine/internals/native_engine.pyi | 22 +++ .../pants/engine/internals/rule_visitor.py | 12 +- .../pants/engine/internals/scheduler_test.py | 3 - .../pants/engine/internals/selectors.py | 118 +++++++-------- src/python/pants/engine/rules.py | 31 ++-- src/rust/engine/src/externs/mod.rs | 134 ++++++++++++++++-- src/rust/engine/src/nodes/task.rs | 13 +- 7 files changed, 231 insertions(+), 102 deletions(-) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index f9623e05826..0218b6c07fb 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -1193,6 +1193,28 @@ class PyGeneratorResponseCall: input_arg1: _Input | None = None, ) -> None: ... +class RuleCallTrampoline: + """ + The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so + each invocation constructs the already-awaitable `Call` directly. + `__getattribute__` forwards `__doc__` and other introspection attrs to the wrapped function. + """ + + rule_id: str + output_type: type + rule: Any + __wrapped__: Callable[..., Any] + + def __init__( + self, + rule_id: str, + output_type: type, + call_cls: type, + wrapped: Callable[..., Any], + rule: Any, + ) -> None: ... + def __call__(self, *args: Any, __implicitly: Sequence[Any] = (), **kwargs: Any) -> Any: ... + # ------------------------------------------------------------------------------ # (uncategorized) # ------------------------------------------------------------------------------ diff --git a/src/python/pants/engine/internals/rule_visitor.py b/src/python/pants/engine/internals/rule_visitor.py index 2d00f198019..346061afe12 100644 --- a/src/python/pants/engine/internals/rule_visitor.py +++ b/src/python/pants/engine/internals/rule_visitor.py @@ -16,6 +16,7 @@ import typing_extensions from pants.base.exceptions import RuleTypeError +from pants.engine.internals.native_engine import RuleCallTrampoline from pants.engine.internals.selectors import ( AwaitableConstraints, concurrently, @@ -177,6 +178,11 @@ def _lookup_return_type(func: Callable, check: bool = False) -> Any: class _AwaitableCollector(ast.NodeVisitor): def __init__(self, func: Callable): + # `func` may be a RuleCallTrampoline (the return value of an `@rule`-decorated + # function). `inspect.getsource` and friends only know about real Python functions, + # so follow `__wrapped__` to reach the underlying implementation. + if isinstance(func, RuleCallTrampoline): + func = func.__wrapped__ self.func = func source = inspect.getsource(func) or "" beginning_indent = _get_starting_indent(source) @@ -314,9 +320,9 @@ def _get_byname_awaitable( def visit_Call(self, call_node: ast.Call) -> None: func = self._lookup(call_node.func) if func is not None: - if (inspect.isfunction(func) or isinstance(func, RuleDescriptor)) and ( - rule_id := getattr(func, "rule_id", None) - ) is not None: + if ( + inspect.isfunction(func) or isinstance(func, (RuleDescriptor, RuleCallTrampoline)) + ) and (rule_id := getattr(func, "rule_id", None)) is not None: # Is a direct `@rule` call. self.awaitables.append(self._get_byname_awaitable(rule_id, func, call_node)) elif inspect.iscoroutinefunction(func): diff --git a/src/python/pants/engine/internals/scheduler_test.py b/src/python/pants/engine/internals/scheduler_test.py index 6f66f31e5d3..262ec77a3f2 100644 --- a/src/python/pants/engine/internals/scheduler_test.py +++ b/src/python/pants/engine/internals/scheduler_test.py @@ -474,9 +474,6 @@ def test_trace_includes_nested_exception_traceback() -> None: File LOCATION-INFO, in catch_and_reraise return await raise_an_exception(SomeInput(outer_input.s)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File LOCATION-INFO, in wrapper - return await call - ^^^^^^^^^^ File LOCATION-INFO, in __await__ result = yield self ^^^^^^^^^^ diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index 236adcbb3f4..9b1e6e80c4b 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -43,9 +43,11 @@ def __repr__(self) -> str: @dataclass(frozen=True) class _Concurrently: - calls: tuple[Coroutine, ...] + # A call-by-name `@rule` invocation returns a `Call` directly; async helpers return + # `Coroutine`. The engine accepts both in this sequence. + calls: tuple[Coroutine | Call, ...] - def __await__(self) -> Generator[tuple[Coroutine, ...], None, tuple]: + def __await__(self) -> Generator[tuple[Coroutine | Call, ...], None, tuple]: result = yield self.calls return cast(tuple, result) @@ -238,7 +240,7 @@ async def Concurrently( return await _Concurrently(tuple(__arg0)) if ( - isinstance(__arg0, Coroutine) + isinstance(__arg0, (Coroutine, Call)) and __arg1 is None and __arg2 is None and __arg3 is None @@ -253,8 +255,8 @@ async def Concurrently( return await _Concurrently((__arg0,)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) and __arg2 is None and __arg3 is None and __arg4 is None @@ -268,9 +270,9 @@ async def Concurrently( return await _Concurrently((__arg0, __arg1)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) and __arg3 is None and __arg4 is None and __arg5 is None @@ -283,10 +285,10 @@ async def Concurrently( return await _Concurrently((__arg0, __arg1, __arg2)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) + and isinstance(__arg3, (Coroutine, Call)) and __arg4 is None and __arg5 is None and __arg6 is None @@ -298,11 +300,11 @@ async def Concurrently( return await _Concurrently((__arg0, __arg1, __arg2, __arg3)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) + and isinstance(__arg3, (Coroutine, Call)) + and isinstance(__arg4, (Coroutine, Call)) and __arg5 is None and __arg6 is None and __arg7 is None @@ -313,12 +315,12 @@ async def Concurrently( return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) + and isinstance(__arg3, (Coroutine, Call)) + and isinstance(__arg4, (Coroutine, Call)) + and isinstance(__arg5, (Coroutine, Call)) and __arg6 is None and __arg7 is None and __arg8 is None @@ -328,13 +330,13 @@ async def Concurrently( return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) + and isinstance(__arg3, (Coroutine, Call)) + and isinstance(__arg4, (Coroutine, Call)) + and isinstance(__arg5, (Coroutine, Call)) + and isinstance(__arg6, (Coroutine, Call)) and __arg7 is None and __arg8 is None and __arg9 is None @@ -343,14 +345,14 @@ async def Concurrently( return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) - and isinstance(__arg7, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) + and isinstance(__arg3, (Coroutine, Call)) + and isinstance(__arg4, (Coroutine, Call)) + and isinstance(__arg5, (Coroutine, Call)) + and isinstance(__arg6, (Coroutine, Call)) + and isinstance(__arg7, (Coroutine, Call)) and __arg8 is None and __arg9 is None and not __args @@ -358,15 +360,15 @@ async def Concurrently( return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6, __arg7)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) - and isinstance(__arg7, Coroutine) - and isinstance(__arg8, Coroutine) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) + and isinstance(__arg3, (Coroutine, Call)) + and isinstance(__arg4, (Coroutine, Call)) + and isinstance(__arg5, (Coroutine, Call)) + and isinstance(__arg6, (Coroutine, Call)) + and isinstance(__arg7, (Coroutine, Call)) + and isinstance(__arg8, (Coroutine, Call)) and __arg9 is None and not __args ): @@ -375,17 +377,17 @@ async def Concurrently( ) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) - and isinstance(__arg7, Coroutine) - and isinstance(__arg8, Coroutine) - and isinstance(__arg9, Coroutine) - and all(isinstance(arg, Coroutine) for arg in __args) + isinstance(__arg0, (Coroutine, Call)) + and isinstance(__arg1, (Coroutine, Call)) + and isinstance(__arg2, (Coroutine, Call)) + and isinstance(__arg3, (Coroutine, Call)) + and isinstance(__arg4, (Coroutine, Call)) + and isinstance(__arg5, (Coroutine, Call)) + and isinstance(__arg6, (Coroutine, Call)) + and isinstance(__arg7, (Coroutine, Call)) + and isinstance(__arg8, (Coroutine, Call)) + and isinstance(__arg9, (Coroutine, Call)) + and all(isinstance(arg, (Coroutine, Call)) for arg in __args) ): return await _Concurrently( ( diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index a17e757a860..1fcc415e644 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -3,10 +3,9 @@ from __future__ import annotations -import functools import inspect import sys -from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence +from collections.abc import Callable, Coroutine, Iterable, Mapping from dataclasses import dataclass from enum import Enum from types import FrameType, ModuleType @@ -25,6 +24,7 @@ from typing_extensions import ParamSpec from pants.engine.engine_aware import SideEffecting +from pants.engine.internals.native_engine import RuleCallTrampoline from pants.engine.internals.rule_visitor import collect_awaitables from pants.engine.internals.selectors import AwaitableConstraints, Call from pants.engine.internals.selectors import concurrently as concurrently # noqa: F401 @@ -57,17 +57,6 @@ class RuleType(Enum): RuleDecorator = Callable[[SyncRuleT | AsyncRuleT], AsyncRuleT] -def _rule_call_trampoline( - rule_id: str, output_type: type[Any], func: Callable[P, R] -) -> Callable[P, R]: - @functools.wraps(func) # type: ignore - async def wrapper(*args, __implicitly: Sequence[Any] = (), **kwargs): - call = Call(rule_id, output_type, args, *__implicitly) - return await call - - return cast(Callable[P, R], wrapper) - - def _make_rule( func_id: str, rule_type: RuleType, @@ -112,13 +101,12 @@ def wrapper(original_func): awaitables = FrozenOrderedSet(collect_awaitables(original_func)) validate_requirements(func_id, parameter_types, awaitables, cacheable) - func = _rule_call_trampoline(canonical_name, return_type, original_func) - # NB: The named definition of the rule ends up wrapped in a trampoline to handle memoization - # and implicit arguments for direct by-name calls. But the `TaskRule` takes a reference to - # the original unwrapped function, which avoids the need for a special protocol when the - # engine invokes a @rule under memoization. - func.rule = TaskRule( + # NB: The named definition of the rule ends up wrapped in a trampoline to handle + # implicit arguments for direct by-name calls. The `TaskRule` takes a reference to + # the original unwrapped function, which avoids the need for a special protocol when + # the engine invokes a @rule under memoization. + task_rule = TaskRule( return_type, FrozenDict(parameter_types), awaitables, @@ -130,7 +118,10 @@ def wrapper(original_func): cacheable=cacheable, polymorphic=polymorphic, ) - return func + return cast( + Callable[P, R], + RuleCallTrampoline(canonical_name, return_type, Call, original_func, task_rule), + ) return wrapper diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 51e4e352685..dc1fb246c20 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -10,6 +10,7 @@ use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; use pyo3::FromPyObject; use pyo3::exceptions::{PyException, PyStopIteration, PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; use pyo3::sync::{MutexExt, RwLockExt}; use pyo3::types::{PyBool, PyBytes, PyDict, PySequence, PyString, PyTuple, PyType}; use pyo3::{create_exception, import_exception, intern}; @@ -50,6 +51,7 @@ pub fn register(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add("EngineError", py.get_type::())?; m.add("IntrinsicError", py.get_type::())?; @@ -354,15 +356,18 @@ pub(crate) fn generator_send( } else if let Ok(call) = response.extract::>() { Ok(GeneratorResponse::NativeCall(call.take(py)?)) } else if let Ok(get_multi) = response.cast::() { - // Was an `All` or `concurrently`. - let generators = get_multi + // Was an `All` or `concurrently`. Each item is either a generator (from an async + // helper) or a direct `Call` (from a call-by-name `@rule` invocation). + let items = get_multi .try_iter()? - .map(|generator| { - let generator = generator?; - // TODO: Find a better way to check whether something is a coroutine... this seems - // unnecessarily awkward. - if generator.is_instance(&generator_type.as_py_type(py))? { - Ok(Value::new(generator.unbind())) + .map(|item| { + let item = item?; + if item.is_instance(&generator_type.as_py_type(py))? { + Ok(AllItem::Generator(Value::new(item.unbind()))) + } else if let Ok(call) = item.extract::>() { + call.take(py) + .map(AllItem::Call) + .map_err(PyValueError::new_err) } else { Err(PyValueError::new_err(format!( "Expected an `All` or `concurrently` to receive calls to rules, \ @@ -371,7 +376,7 @@ pub(crate) fn generator_send( } }) .collect::, _>>()?; - Ok(GeneratorResponse::All(generators)) + Ok(GeneratorResponse::All(items)) } else { Err(PyValueError::new_err(format!( "Async @rule error. Expected a rule call, but got: {response}" @@ -566,6 +571,101 @@ impl PyGeneratorResponseCall { } } +/// The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so +/// each invocation constructs the already-awaitable `Call` directly. +/// `__getattribute__` forwards `__doc__` and other introspection attrs to the wrapped function. +#[pyclass(frozen, module = "pants.engine.internals.native_engine")] +pub struct RuleCallTrampoline { + rule_id: PyBackedStr, + #[pyo3(get)] + output_type: Py, + call_cls: Py, + #[pyo3(get, name = "__wrapped__")] + wrapped: Py, + #[pyo3(get)] + rule: Py, +} + +#[pymethods] +impl RuleCallTrampoline { + #[new] + fn __new__( + rule_id: PyBackedStr, + output_type: Py, + call_cls: Py, + wrapped: Py, + rule: Py, + ) -> Self { + Self { + rule_id, + output_type, + call_cls, + wrapped, + rule, + } + } + + #[getter] + fn rule_id(&self) -> &PyBackedStr { + &self.rule_id + } + + #[pyo3(signature = (*args, __implicitly=None, **_kwargs))] + fn __call__<'py>( + &self, + py: Python<'py>, + args: &Bound<'py, PyTuple>, + __implicitly: Option<&Bound<'py, PyTuple>>, + _kwargs: Option<&Bound<'py, PyDict>>, + ) -> PyResult> { + let call_cls = self.call_cls.bind(py); + let output_type = self.output_type.bind(py); + let input_arg0 = match __implicitly { + Some(t) if !t.is_empty() => Some(t.get_item(0)?), + _ => None, + }; + match input_arg0 { + Some(arg) => call_cls.call1((&self.rule_id, output_type, args, arg)), + None => call_cls.call1((&self.rule_id, output_type, args)), + } + } + + /// Forward unknown attribute lookups (`__name__`, `__qualname__`, `__module__`, + /// `__line_number__`, custom attrs, ...) to the wrapped function so introspection that + /// would've relied on `functools.wraps` still works without a per-instance `__dict__`. + fn __getattr__<'py>(&self, py: Python<'py>, name: &str) -> PyResult> { + self.wrapped.bind(py).getattr(name) + } + + /// `__doc__` lives in the type object's `tp_doc` slot, which shadows `#[getter]` and + /// `__getattr__`. Intercepting at `tp_getattro` via `__getattribute__` is the only hook + /// that fires before the slot is read. See PyO3/pyo3#2187. + fn __getattribute__<'py>( + slf: PyRef<'py, Self>, + name: Bound<'py, PyString>, + ) -> PyResult> { + let py = slf.py(); + if name.to_cow()? == "__doc__" { + return slf.wrapped.bind(py).getattr(intern!(py, "__doc__")); + } + unsafe { + Bound::from_owned_ptr_or_err( + py, + pyo3::ffi::PyObject_GenericGetAttr(slf.as_ptr(), name.as_ptr()), + ) + } + } + + fn __repr__(&self, py: Python) -> PyResult { + let name: String = self + .wrapped + .bind(py) + .getattr(intern!(py, "__name__"))? + .extract()?; + Ok(format!("")) + } +} + pub struct NativeCall { pub call: BoxFuture<'static, Result>, } @@ -609,10 +709,16 @@ pub enum GeneratorResponse { NativeCall(NativeCall), /// The generator is awaiting a call to a known rule. Call(Call), - /// The generator is awaiting calls to a series of generators, all of which will - /// produce `Call`s. + /// The generator is awaiting completion of a series of awaitables. /// - /// The generators used in this position will either be call-by-name `@rule` stubs (which will - /// immediately produce a `Call`, and then return its value), or async "rule helpers". - All(Vec), + /// Each entry is either a rule `Call` (from a direct `RuleCallTrampoline` invocation, no + /// coroutine wrapper) or a generator from an async rule helper. + All(Vec), +} + +pub enum AllItem { + /// A generator produced by an `async def` rule helper; drive it via generator_send. + Generator(Value), + /// A direct `Call` returned by a call-by-name `@rule` invocation; execute it as-is. + Call(Call), } diff --git a/src/rust/engine/src/nodes/task.rs b/src/rust/engine/src/nodes/task.rs index 3e4b4a374f0..00994a399a7 100644 --- a/src/rust/engine/src/nodes/task.rs +++ b/src/rust/engine/src/nodes/task.rs @@ -164,12 +164,17 @@ impl Task { Err(failure) => break Err(failure), } } - GeneratorResponse::All(generators) => { + GeneratorResponse::All(items) => { let _blocking_token = workunit.blocking(); - let get_futures = generators + let get_futures = items .into_iter() - .map(|generator| { - Self::gen_generator(context, params.clone(), entry, generator) + .map(|item| match item { + externs::AllItem::Generator(generator) => { + Self::gen_generator(context, params.clone(), entry, generator) + } + externs::AllItem::Call(call) => { + Self::gen_call(context, params.clone(), entry, call).boxed() + } }) .collect::>(); match future::try_join_all(get_futures).await { From 92b73798b6c39fc86ef9bb72169fdde8b78c0cd6 Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Thu, 16 Apr 2026 23:33:49 +0200 Subject: [PATCH 2/8] perf: Remove python subclassing for Call --- .../pants/engine/internals/native_engine.pyi | 12 +- .../pants/engine/internals/scheduler_test.py | 3 - .../pants/engine/internals/selectors.py | 13 +- src/python/pants/engine/rules.py | 4 +- src/rust/engine/src/externs/mod.rs | 114 ++++++++++++++---- 5 files changed, 98 insertions(+), 48 deletions(-) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index 0218b6c07fb..2a5691acd0d 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -6,7 +6,7 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence from datetime import datetime from enum import Enum from io import RawIOBase @@ -1157,7 +1157,7 @@ def hash_prefix_zero_bits(item: str) -> int: ... _Output = TypeVar("_Output") _Input = TypeVar("_Input") -class PyGeneratorResponseCall: +class Call: rule_id: str output_type: type inputs: Sequence[Any] @@ -1192,10 +1192,11 @@ class PyGeneratorResponseCall: input_arg0: type[_Input] | _Input, input_arg1: _Input | None = None, ) -> None: ... + def __await__(self) -> Generator[Any, None, Any]: ... + def __repr__(self) -> str: ... class RuleCallTrampoline: - """ - The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so + """The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so each invocation constructs the already-awaitable `Call` directly. `__getattribute__` forwards `__doc__` and other introspection attrs to the wrapped function. """ @@ -1209,11 +1210,10 @@ class RuleCallTrampoline: self, rule_id: str, output_type: type, - call_cls: type, wrapped: Callable[..., Any], rule: Any, ) -> None: ... - def __call__(self, *args: Any, __implicitly: Sequence[Any] = (), **kwargs: Any) -> Any: ... + def __call__(self, *args: Any, __implicitly: Sequence[Any] = (), **kwargs: Any) -> Call: ... # ------------------------------------------------------------------------------ # (uncategorized) diff --git a/src/python/pants/engine/internals/scheduler_test.py b/src/python/pants/engine/internals/scheduler_test.py index 262ec77a3f2..2e3f4711584 100644 --- a/src/python/pants/engine/internals/scheduler_test.py +++ b/src/python/pants/engine/internals/scheduler_test.py @@ -474,9 +474,6 @@ def test_trace_includes_nested_exception_traceback() -> None: File LOCATION-INFO, in catch_and_reraise return await raise_an_exception(SomeInput(outer_input.s)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File LOCATION-INFO, in __await__ - result = yield self - ^^^^^^^^^^ File LOCATION-INFO, in raise_an_exception raise Exception(some_input.s) Exception: asdf diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index 9b1e6e80c4b..c365e0f14f0 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from typing import Any, TypeVar, cast, overload -from pants.engine.internals.native_engine import PyGeneratorResponseCall +from pants.engine.internals.native_engine import Call as Call # noqa: F401 from pants.util.strutil import softwrap _Output = TypeVar("_Output") @@ -30,17 +30,6 @@ def __str__(self) -> str: return repr(self) -class Call(PyGeneratorResponseCall): - def __await__( - self, - ) -> Generator[Any, None, Any]: - result = yield self - return result - - def __repr__(self) -> str: - return f"Call({self.rule_id}(...) -> {self.output_type.__name__})" - - @dataclass(frozen=True) class _Concurrently: # A call-by-name `@rule` invocation returns a `Call` directly; async helpers return diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index 1fcc415e644..ea0d5e30f49 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -26,7 +26,7 @@ from pants.engine.engine_aware import SideEffecting from pants.engine.internals.native_engine import RuleCallTrampoline from pants.engine.internals.rule_visitor import collect_awaitables -from pants.engine.internals.selectors import AwaitableConstraints, Call +from pants.engine.internals.selectors import AwaitableConstraints from pants.engine.internals.selectors import concurrently as concurrently # noqa: F401 from pants.engine.unions import UnionRule from pants.util.frozendict import FrozenDict @@ -120,7 +120,7 @@ def wrapper(original_func): ) return cast( Callable[P, R], - RuleCallTrampoline(canonical_name, return_type, Call, original_func, task_rule), + RuleCallTrampoline(canonical_name, return_type, original_func, task_rule), ) return wrapper diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index dc1fb246c20..82327b1c3f2 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -481,7 +481,7 @@ impl PyGeneratorResponseNativeCall { } } -#[pyclass(subclass)] +#[pyclass(name = "Call", module = "pants.engine.internals.native_engine")] pub struct PyGeneratorResponseCall(RwLock>); impl PyGeneratorResponseCall { @@ -496,15 +496,10 @@ impl PyGeneratorResponseCall { )) } } -} -#[pymethods] -impl PyGeneratorResponseCall { - #[new] - #[pyo3(signature = (rule_id, output_type, args, input_arg0=None))] - fn __new__( - py: Python, - rule_id: String, + pub(crate) fn construct( + py: Python<'_>, + rule_id: &str, output_type: &Bound<'_, PyType>, args: &Bound<'_, PyTuple>, input_arg0: Option>, @@ -526,7 +521,7 @@ impl PyGeneratorResponseCall { let (input_types, inputs) = interpret_implicit_args(py, input_arg0)?; Ok(Self(RwLock::new(Some(Call { - rule_id: RuleId::from_string(rule_id), + rule_id: RuleId::new(rule_id), output_type, args, args_arity, @@ -534,6 +529,21 @@ impl PyGeneratorResponseCall { inputs, })))) } +} + +#[pymethods] +impl PyGeneratorResponseCall { + #[new] + #[pyo3(signature = (rule_id, output_type, args, input_arg0=None))] + fn __new__( + py: Python, + rule_id: PyBackedStr, + output_type: &Bound<'_, PyType>, + args: &Bound<'_, PyTuple>, + input_arg0: Option>, + ) -> PyResult { + Self::construct(py, &rule_id, output_type, args, input_arg0) + } #[getter] fn rule_id(&self, py: Python) -> PyResult { @@ -560,6 +570,61 @@ impl PyGeneratorResponseCall { .chain(inner.inputs.iter().map(Key::to_py_object)) .collect()) } + + fn __await__(slf: Py) -> CallAwaitable { + CallAwaitable(Mutex::new(Some(slf))) + } + + fn __repr__(&self, py: Python) -> PyResult { + let inner = self.borrow_inner(py)?; + let output_type_name = inner.output_type.as_py_type(py).name()?.to_string(); + Ok(format!( + "Call({}(...) -> {output_type_name})", + inner.rule_id + )) + } +} + +/// The iterator returned by `PyGeneratorResponseCall.__await__`. Yields the Call to the +/// scheduler once, then raises `StopIteration(result)` when the scheduler sends the +/// rule's result back through `send`. +#[pyclass(frozen, module = "pants.engine.internals.native_engine")] +pub struct CallAwaitable(Mutex>>); + +#[pymethods] +impl CallAwaitable { + fn __iter__(slf: Py) -> Py { + slf + } + + fn __next__(&self, py: Python<'_>) -> PyResult> { + self.send(py, None) + } + + #[pyo3(signature = (value=None))] + fn send( + &self, + py: Python<'_>, + value: Option>, + ) -> PyResult> { + let mut state = self.0.lock_py_attached(py); + if state.is_some() && matches!(&value, Some(v) if !v.is_none(py)) { + return Err(PyTypeError::new_err( + "can't send non-None value to a just-started generator", + )); + } + state + .take() + .ok_or_else(|| PyStopIteration::new_err((value.unwrap_or_else(|| py.None()),))) + } + + fn throw(&self, exc: Bound<'_, PyAny>) -> PyResult> { + Err(PyErr::from_value(exc)) + } + + fn close(&self, py: Python<'_>) { + *self.0.lock_py_attached(py) = None; + } } impl PyGeneratorResponseCall { @@ -579,7 +644,6 @@ pub struct RuleCallTrampoline { rule_id: PyBackedStr, #[pyo3(get)] output_type: Py, - call_cls: Py, #[pyo3(get, name = "__wrapped__")] wrapped: Py, #[pyo3(get)] @@ -592,14 +656,12 @@ impl RuleCallTrampoline { fn __new__( rule_id: PyBackedStr, output_type: Py, - call_cls: Py, wrapped: Py, rule: Py, ) -> Self { Self { rule_id, output_type, - call_cls, wrapped, rule, } @@ -611,23 +673,25 @@ impl RuleCallTrampoline { } #[pyo3(signature = (*args, __implicitly=None, **_kwargs))] - fn __call__<'py>( + fn __call__( &self, - py: Python<'py>, - args: &Bound<'py, PyTuple>, - __implicitly: Option<&Bound<'py, PyTuple>>, - _kwargs: Option<&Bound<'py, PyDict>>, - ) -> PyResult> { - let call_cls = self.call_cls.bind(py); - let output_type = self.output_type.bind(py); + py: Python<'_>, + args: &Bound<'_, PyTuple>, + __implicitly: Option<&Bound<'_, PyTuple>>, + _kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult> { let input_arg0 = match __implicitly { Some(t) if !t.is_empty() => Some(t.get_item(0)?), _ => None, }; - match input_arg0 { - Some(arg) => call_cls.call1((&self.rule_id, output_type, args, arg)), - None => call_cls.call1((&self.rule_id, output_type, args)), - } + let call = PyGeneratorResponseCall::construct( + py, + &self.rule_id, + self.output_type.bind(py), + args, + input_arg0, + )?; + Py::new(py, call) } /// Forward unknown attribute lookups (`__name__`, `__qualname__`, `__module__`, From adb6685faaf877fb161085110aec04b05e420839 Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Thu, 16 Apr 2026 23:53:35 +0200 Subject: [PATCH 3/8] perf: Remove python async overhead from Concurrently calls --- .../pants/engine/internals/native_engine.pyi | 9 +- .../pants/engine/internals/selectors.py | 200 +++++++++--------- src/rust/engine/src/externs/mod.rs | 68 ++++-- 3 files changed, 156 insertions(+), 121 deletions(-) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index 2a5691acd0d..bc643f2db65 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -6,7 +6,7 @@ from __future__ import annotations -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Coroutine, Generator, Iterable, Iterator, Mapping, Sequence from datetime import datetime from enum import Enum from io import RawIOBase @@ -1195,6 +1195,13 @@ class Call: def __await__(self) -> Generator[Any, None, Any]: ... def __repr__(self) -> str: ... +class _Concurrently: + calls: tuple[Coroutine[Any, Any, Any] | Call, ...] + def __init__(self, calls: tuple[Coroutine[Any, Any, Any] | Call, ...]) -> None: ... + def __await__( + self, + ) -> Generator[tuple[Coroutine[Any, Any, Any] | Call, ...], None, tuple[Any, ...]]: ... + class RuleCallTrampoline: """The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so each invocation constructs the already-awaitable `Call` directly. diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index c365e0f14f0..d9c18a9f173 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -4,11 +4,12 @@ from __future__ import annotations import itertools -from collections.abc import Coroutine, Generator, Iterable +from collections.abc import Awaitable, Coroutine, Iterable from dataclasses import dataclass -from typing import Any, TypeVar, cast, overload +from typing import Any, TypeVar, overload from pants.engine.internals.native_engine import Call as Call # noqa: F401 +from pants.engine.internals.native_engine import _Concurrently from pants.util.strutil import softwrap _Output = TypeVar("_Output") @@ -30,17 +31,6 @@ def __str__(self) -> str: return repr(self) -@dataclass(frozen=True) -class _Concurrently: - # A call-by-name `@rule` invocation returns a `Call` directly; async helpers return - # `Coroutine`. The engine accepts both in this sequence. - calls: tuple[Coroutine | Call, ...] - - def __await__(self) -> Generator[tuple[Coroutine | Call, ...], None, tuple]: - result = yield self.calls - return cast(tuple, result) - - # These type variables are used to parametrize from 1 to 10 args when used in a tuple-style # concurrently() call. @@ -57,13 +47,13 @@ def __await__(self) -> Generator[tuple[Coroutine | Call, ...], None, tuple]: @overload -async def Concurrently( +def Concurrently( __gets: Iterable[Coroutine[Any, Any, _Output]], -) -> tuple[_Output, ...]: ... +) -> Awaitable[tuple[_Output, ...]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Output], __get1: Coroutine[Any, Any, _Output], __get2: Coroutine[Any, Any, _Output], @@ -76,11 +66,11 @@ async def Concurrently( __get9: Coroutine[Any, Any, _Output], __get10: Coroutine[Any, Any, _Output], *__gets: Coroutine[Any, Any, _Output], -) -> tuple[_Output, ...]: ... +) -> Awaitable[tuple[_Output, ...]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], @@ -91,11 +81,11 @@ async def Concurrently( __get7: Coroutine[Any, Any, _Out7], __get8: Coroutine[Any, Any, _Out8], __get9: Coroutine[Any, Any, _Out9], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], @@ -105,11 +95,11 @@ async def Concurrently( __get6: Coroutine[Any, Any, _Out6], __get7: Coroutine[Any, Any, _Out7], __get8: Coroutine[Any, Any, _Out8], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], @@ -118,11 +108,11 @@ async def Concurrently( __get5: Coroutine[Any, Any, _Out5], __get6: Coroutine[Any, Any, _Out6], __get7: Coroutine[Any, Any, _Out7], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], @@ -130,55 +120,55 @@ async def Concurrently( __get4: Coroutine[Any, Any, _Out4], __get5: Coroutine[Any, Any, _Out5], __get6: Coroutine[Any, Any, _Out6], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], __get3: Coroutine[Any, Any, _Out3], __get4: Coroutine[Any, Any, _Out4], __get5: Coroutine[Any, Any, _Out5], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], __get3: Coroutine[Any, Any, _Out3], __get4: Coroutine[Any, Any, _Out4], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], __get3: Coroutine[Any, Any, _Out3], -) -> tuple[_Out0, _Out1, _Out2, _Out3]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], __get2: Coroutine[Any, Any, _Out2], -) -> tuple[_Out0, _Out1, _Out2]: ... +) -> Awaitable[tuple[_Out0, _Out1, _Out2]]: ... @overload -async def Concurrently( +def Concurrently( __get0: Coroutine[Any, Any, _Out0], __get1: Coroutine[Any, Any, _Out1], -) -> tuple[_Out0, _Out1]: ... +) -> Awaitable[tuple[_Out0, _Out1]]: ... -async def Concurrently( +def Concurrently( __arg0: (Iterable[Coroutine[Any, Any, _Output]] | Coroutine[Any, Any, _Out0]), __arg1: Coroutine[Any, Any, _Out1] | None = None, __arg2: Coroutine[Any, Any, _Out2] | None = None, @@ -190,7 +180,7 @@ async def Concurrently( __arg8: Coroutine[Any, Any, _Out8] | None = None, __arg9: Coroutine[Any, Any, _Out9] | None = None, *__args: Coroutine[Any, Any, _Output], -) -> ( +) -> Awaitable[ tuple[_Output, ...] | tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9] | tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8] @@ -202,7 +192,7 @@ async def Concurrently( | tuple[_Out0, _Out1, _Out2] | tuple[_Out0, _Out1] | tuple[_Out0] -): +]: """Yield a tuple of Coroutine instances all at once. The `yield`ed value `self.calls` is interpreted by the engine within @@ -226,10 +216,10 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently(tuple(__arg0)) + return _Concurrently(tuple(__arg0)) if ( - isinstance(__arg0, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) and __arg1 is None and __arg2 is None and __arg3 is None @@ -241,11 +231,11 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0,)) + return _Concurrently((__arg0,)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) and __arg2 is None and __arg3 is None and __arg4 is None @@ -256,12 +246,12 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1)) + return _Concurrently((__arg0, __arg1)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) and __arg3 is None and __arg4 is None and __arg5 is None @@ -271,13 +261,13 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2)) + return _Concurrently((__arg0, __arg1, __arg2)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) - and isinstance(__arg3, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) and __arg4 is None and __arg5 is None and __arg6 is None @@ -286,14 +276,14 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) - and isinstance(__arg3, (Coroutine, Call)) - and isinstance(__arg4, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) and __arg5 is None and __arg6 is None and __arg7 is None @@ -301,84 +291,84 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) - and isinstance(__arg3, (Coroutine, Call)) - and isinstance(__arg4, (Coroutine, Call)) - and isinstance(__arg5, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) and __arg6 is None and __arg7 is None and __arg8 is None and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) - and isinstance(__arg3, (Coroutine, Call)) - and isinstance(__arg4, (Coroutine, Call)) - and isinstance(__arg5, (Coroutine, Call)) - and isinstance(__arg6, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) and __arg7 is None and __arg8 is None and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) - and isinstance(__arg3, (Coroutine, Call)) - and isinstance(__arg4, (Coroutine, Call)) - and isinstance(__arg5, (Coroutine, Call)) - and isinstance(__arg6, (Coroutine, Call)) - and isinstance(__arg7, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) + and isinstance(__arg7, (Coroutine, Call, _Concurrently)) and __arg8 is None and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6, __arg7)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6, __arg7)) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) - and isinstance(__arg3, (Coroutine, Call)) - and isinstance(__arg4, (Coroutine, Call)) - and isinstance(__arg5, (Coroutine, Call)) - and isinstance(__arg6, (Coroutine, Call)) - and isinstance(__arg7, (Coroutine, Call)) - and isinstance(__arg8, (Coroutine, Call)) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) + and isinstance(__arg7, (Coroutine, Call, _Concurrently)) + and isinstance(__arg8, (Coroutine, Call, _Concurrently)) and __arg9 is None and not __args ): - return await _Concurrently( + return _Concurrently( (__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6, __arg7, __arg8) ) if ( - isinstance(__arg0, (Coroutine, Call)) - and isinstance(__arg1, (Coroutine, Call)) - and isinstance(__arg2, (Coroutine, Call)) - and isinstance(__arg3, (Coroutine, Call)) - and isinstance(__arg4, (Coroutine, Call)) - and isinstance(__arg5, (Coroutine, Call)) - and isinstance(__arg6, (Coroutine, Call)) - and isinstance(__arg7, (Coroutine, Call)) - and isinstance(__arg8, (Coroutine, Call)) - and isinstance(__arg9, (Coroutine, Call)) - and all(isinstance(arg, (Coroutine, Call)) for arg in __args) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) + and isinstance(__arg7, (Coroutine, Call, _Concurrently)) + and isinstance(__arg8, (Coroutine, Call, _Concurrently)) + and isinstance(__arg9, (Coroutine, Call, _Concurrently)) + and all(isinstance(arg, (Coroutine, Call, _Concurrently)) for arg in __args) ): - return await _Concurrently( + return _Concurrently( ( __arg0, __arg1, diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 82327b1c3f2..3ea4fc4b274 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -51,6 +51,7 @@ pub fn register(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add("EngineError", py.get_type::())?; @@ -356,8 +357,11 @@ pub(crate) fn generator_send( } else if let Ok(call) = response.extract::>() { Ok(GeneratorResponse::NativeCall(call.take(py)?)) } else if let Ok(get_multi) = response.cast::() { - // Was an `All` or `concurrently`. Each item is either a generator (from an async - // helper) or a direct `Call` (from a call-by-name `@rule` invocation). + // Was an `All` or `concurrently`. Each item is one of: + // * a generator (async helper) — drive via generator_send; + // * a `Call` — dispatch directly via `gen_call`; + // * a `_Concurrently` — nested `concurrently(...)`. Treat its awaiter as a + // generator so the outer engine loop recurses into the inner tuple. let items = get_multi .try_iter()? .map(|item| { @@ -368,6 +372,9 @@ pub(crate) fn generator_send( call.take(py) .map(AllItem::Call) .map_err(PyValueError::new_err) + } else if let Ok(concurrently) = item.extract::>() { + let awaiter = Py::new(py, concurrently.awaiter(py))?; + Ok(AllItem::Generator(Value::new(awaiter.into_any()))) } else { Err(PyValueError::new_err(format!( "Expected an `All` or `concurrently` to receive calls to rules, \ @@ -571,8 +578,8 @@ impl PyGeneratorResponseCall { .collect()) } - fn __await__(slf: Py) -> CallAwaitable { - CallAwaitable(Mutex::new(Some(slf))) + fn __await__(slf: Py) -> YieldOnce { + YieldOnce::new(slf) } fn __repr__(&self, py: Python) -> PyResult { @@ -585,28 +592,30 @@ impl PyGeneratorResponseCall { } } -/// The iterator returned by `PyGeneratorResponseCall.__await__`. Yields the Call to the -/// scheduler once, then raises `StopIteration(result)` when the scheduler sends the -/// rule's result back through `send`. +/// A generator-protocol iterator that yields one value to the engine on the first `send`, +/// then raises `StopIteration(result)` when the engine sends the result back. Used to +/// implement `Call.__await__` and `_Concurrently.__await__`. #[pyclass(frozen, module = "pants.engine.internals.native_engine")] -pub struct CallAwaitable(Mutex>>); +pub struct YieldOnce(Mutex>>); + +impl YieldOnce { + fn new(value: Py) -> Self { + Self(Mutex::new(Some(value.into_any()))) + } +} #[pymethods] -impl CallAwaitable { +impl YieldOnce { fn __iter__(slf: Py) -> Py { slf } - fn __next__(&self, py: Python<'_>) -> PyResult> { + fn __next__(&self, py: Python<'_>) -> PyResult> { self.send(py, None) } #[pyo3(signature = (value=None))] - fn send( - &self, - py: Python<'_>, - value: Option>, - ) -> PyResult> { + fn send(&self, py: Python<'_>, value: Option>) -> PyResult> { let mut state = self.0.lock_py_attached(py); if state.is_some() && matches!(&value, Some(v) if !v.is_none(py)) { return Err(PyTypeError::new_err( @@ -627,6 +636,35 @@ impl CallAwaitable { } } +/// Yielded by `concurrently(...)` to hand a tuple of awaitables (each either a `Call` or a +/// coroutine from an async helper) to the engine for parallel execution. +#[pyclass( + frozen, + name = "_Concurrently", + module = "pants.engine.internals.native_engine" +)] +pub struct PyConcurrently { + calls: Py, +} + +impl PyConcurrently { + pub(crate) fn awaiter(&self, py: Python<'_>) -> YieldOnce { + YieldOnce::new(self.calls.clone_ref(py)) + } +} + +#[pymethods] +impl PyConcurrently { + #[new] + fn __new__(calls: Py) -> Self { + Self { calls } + } + + fn __await__(&self, py: Python<'_>) -> YieldOnce { + self.awaiter(py) + } +} + impl PyGeneratorResponseCall { fn take(&self, py: Python<'_>) -> Result { self.0 From 42a091aaaf7bcd15257c278d0f57f7508cc6edbd Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Fri, 17 Apr 2026 00:08:46 +0200 Subject: [PATCH 4/8] test: Fix mock impl --- .../pants/engine/internals/native_engine.pyi | 3 +- src/python/pants/testutil/rule_runner.py | 40 ++++++------------- src/rust/engine/src/externs/mod.rs | 26 +++++++----- 3 files changed, 32 insertions(+), 37 deletions(-) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index bc643f2db65..f4c9f5719dc 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -1160,7 +1160,8 @@ _Input = TypeVar("_Input") class Call: rule_id: str output_type: type - inputs: Sequence[Any] + args: tuple[Any, ...] + implicit_args: dict[Any, type] @overload def __init__( diff --git a/src/python/pants/testutil/rule_runner.py b/src/python/pants/testutil/rule_runner.py index af776ca8055..3b2bf18c611 100644 --- a/src/python/pants/testutil/rule_runner.py +++ b/src/python/pants/testutil/rule_runner.py @@ -710,34 +710,20 @@ def run_rule_with_mocks( unconsumed_mock_calls = set(mock_calls.keys()) - def get(res: Call | Coroutine): - if isinstance(res, Coroutine): - # A call-by-name element in a concurrently() is a Coroutine whose frame is - # the trampoline wrapper that creates and immediately awaits the Call. - locals = inspect.getcoroutinelocals(res) - assert locals is not None - rule_id = locals["rule_id"] - args = locals["args"] - kwargs = dict(locals["kwargs"]) - __implicitly = locals.get("__implicitly") - if __implicitly: - kwargs["__implicitly"] = __implicitly - mock_call = mock_calls.get(rule_id) - if mock_call: - unconsumed_mock_calls.discard(rule_id) - # Close the original, unmocked, coroutine, to prevent the "was never awaited" - # warning polluting stderr data that the test may examine. - res.close() - return mock_call(*args, **kwargs) - raise AssertionError(f"No mock_call provided for {rule_id}.") - elif isinstance(res, Call): - mock_call = mock_calls.get(res.rule_id) - if mock_call: - unconsumed_mock_calls.discard(res.rule_id) - return mock_call(*res.inputs) - raise AssertionError(f"No mock_call provided for {res.rule_id}.") - else: + def get(res: Call): + if not isinstance(res, Call): raise AssertionError(f"Bad arg type: {res}") + mock_call = mock_calls.get(res.rule_id) + if mock_call is None: + raise AssertionError(f"No mock_call provided for {res.rule_id}.") + unconsumed_mock_calls.discard(res.rule_id) + # NB: if the mock declares an `__implicitly` parameter, forward the raw `(dict,)` so it + # can inspect declared types (e.g. to route polymorphic dispatch); otherwise unpack the + # implicit values positionally. + implicit = res.implicit_args + if implicit and "__implicitly" in inspect.signature(mock_call).parameters: + return mock_call(*res.args, __implicitly=(implicit,)) + return mock_call(*res.args, *implicit) rule_coroutine = res rule_input = None diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 3ea4fc4b274..668913133f1 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -566,16 +566,24 @@ impl PyGeneratorResponseCall { } #[getter] - fn inputs(&self, py: Python<'_>) -> PyResult>> { + fn args<'py>(&self, py: Python<'py>) -> PyResult> { let inner = self.borrow_inner(py)?; - let args: Vec> = inner.args.as_ref().map_or_else( - || Ok(Vec::default()), - |args| args.to_py_object().extract(py), - )?; - Ok(args - .into_iter() - .chain(inner.inputs.iter().map(Key::to_py_object)) - .collect()) + match &inner.args { + Some(args) => Ok(args.to_py_object().extract(py)?), + None => Ok(PyTuple::empty(py)), + } + } + + /// NB: keyed on the value, mirroring the `{value: declared_type}` shape of `implicitly(...)` + /// at the call site so test mocks can read declared types back off the dict. + #[getter] + fn implicit_args<'py>(&self, py: Python<'py>) -> PyResult> { + let inner = self.borrow_inner(py)?; + let d = PyDict::new(py); + for (typ, val) in inner.input_types.iter().zip(inner.inputs.iter()) { + d.set_item(val.to_py_object(), typ.as_py_type(py))?; + } + Ok(d) } fn __await__(slf: Py) -> YieldOnce { From 4937324a13589714f6078fda72e8ad77771012f0 Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Fri, 17 Apr 2026 11:46:56 +0200 Subject: [PATCH 5/8] typing: Accurately describe supported inputs And correctly propagate outputs --- .../pants/engine/internals/native_engine.pyi | 13 +- .../pants/engine/internals/selectors.py | 186 +++++++++--------- src/rust/engine/src/externs/mod.rs | 1 + 3 files changed, 105 insertions(+), 95 deletions(-) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index f4c9f5719dc..b37bcb674b8 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -1155,6 +1155,7 @@ def hash_prefix_zero_bits(item: str) -> int: ... # ------------------------------------------------------------------------------ _Output = TypeVar("_Output") +_Output_co = TypeVar("_Output_co", covariant=True) _Input = TypeVar("_Input") class Call: @@ -1196,12 +1197,16 @@ class Call: def __await__(self) -> Generator[Any, None, Any]: ... def __repr__(self) -> str: ... -class _Concurrently: - calls: tuple[Coroutine[Any, Any, Any] | Call, ...] - def __init__(self, calls: tuple[Coroutine[Any, Any, Any] | Call, ...]) -> None: ... +class _Concurrently(Generic[_Output_co]): + calls: tuple[Coroutine[Any, Any, Any] | Call | _Concurrently[Any], ...] + def __init__( + self, calls: tuple[Coroutine[Any, Any, Any] | Call | _Concurrently[Any], ...] + ) -> None: ... def __await__( self, - ) -> Generator[tuple[Coroutine[Any, Any, Any] | Call, ...], None, tuple[Any, ...]]: ... + ) -> Generator[ + tuple[Coroutine[Any, Any, Any] | Call | _Concurrently[Any], ...], None, _Output_co + ]: ... class RuleCallTrampoline: """The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index d9c18a9f173..e7d0c1291d5 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -4,7 +4,7 @@ from __future__ import annotations import itertools -from collections.abc import Awaitable, Coroutine, Iterable +from collections.abc import Coroutine, Iterable from dataclasses import dataclass from typing import Any, TypeVar, overload @@ -48,139 +48,143 @@ def __str__(self) -> str: @overload def Concurrently( - __gets: Iterable[Coroutine[Any, Any, _Output]], -) -> Awaitable[tuple[_Output, ...]]: ... + __gets: Iterable[Coroutine[Any, Any, _Output] | _Concurrently[_Output]], +) -> _Concurrently[tuple[_Output, ...]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Output], - __get1: Coroutine[Any, Any, _Output], - __get2: Coroutine[Any, Any, _Output], - __get3: Coroutine[Any, Any, _Output], - __get4: Coroutine[Any, Any, _Output], - __get5: Coroutine[Any, Any, _Output], - __get6: Coroutine[Any, Any, _Output], - __get7: Coroutine[Any, Any, _Output], - __get8: Coroutine[Any, Any, _Output], - __get9: Coroutine[Any, Any, _Output], - __get10: Coroutine[Any, Any, _Output], - *__gets: Coroutine[Any, Any, _Output], -) -> Awaitable[tuple[_Output, ...]]: ... + __get0: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get1: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get2: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get3: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get4: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get5: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get6: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get7: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get8: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get9: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get10: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + *__gets: Coroutine[Any, Any, _Output] | _Concurrently[_Output], +) -> _Concurrently[tuple[_Output, ...]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], - __get7: Coroutine[Any, Any, _Out7], - __get8: Coroutine[Any, Any, _Out8], - __get9: Coroutine[Any, Any, _Out9], -) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7], + __get8: Coroutine[Any, Any, _Out8] | _Concurrently[_Out8], + __get9: Coroutine[Any, Any, _Out9] | _Concurrently[_Out9], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], - __get7: Coroutine[Any, Any, _Out7], - __get8: Coroutine[Any, Any, _Out8], -) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7], + __get8: Coroutine[Any, Any, _Out8] | _Concurrently[_Out8], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], - __get7: Coroutine[Any, Any, _Out7], -) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], -) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], -) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], -) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3, _Out4]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], -) -> Awaitable[tuple[_Out0, _Out1, _Out2, _Out3]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], -) -> Awaitable[tuple[_Out0, _Out1, _Out2]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], -) -> Awaitable[tuple[_Out0, _Out1]]: ... + __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], +) -> _Concurrently[tuple[_Out0, _Out1]]: ... def Concurrently( - __arg0: (Iterable[Coroutine[Any, Any, _Output]] | Coroutine[Any, Any, _Out0]), - __arg1: Coroutine[Any, Any, _Out1] | None = None, - __arg2: Coroutine[Any, Any, _Out2] | None = None, - __arg3: Coroutine[Any, Any, _Out3] | None = None, - __arg4: Coroutine[Any, Any, _Out4] | None = None, - __arg5: Coroutine[Any, Any, _Out5] | None = None, - __arg6: Coroutine[Any, Any, _Out6] | None = None, - __arg7: Coroutine[Any, Any, _Out7] | None = None, - __arg8: Coroutine[Any, Any, _Out8] | None = None, - __arg9: Coroutine[Any, Any, _Out9] | None = None, - *__args: Coroutine[Any, Any, _Output], -) -> Awaitable[ + __arg0: ( + Iterable[Coroutine[Any, Any, _Output] | _Concurrently[_Output]] + | Coroutine[Any, Any, _Out0] + | _Concurrently[_Out0] + ), + __arg1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1] | None = None, + __arg2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2] | None = None, + __arg3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3] | None = None, + __arg4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4] | None = None, + __arg5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5] | None = None, + __arg6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6] | None = None, + __arg7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7] | None = None, + __arg8: Coroutine[Any, Any, _Out8] | _Concurrently[_Out8] | None = None, + __arg9: Coroutine[Any, Any, _Out9] | _Concurrently[_Out9] | None = None, + *__args: Coroutine[Any, Any, _Output] | _Concurrently[_Output], +) -> _Concurrently[ tuple[_Output, ...] | tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9] | tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8] diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 668913133f1..31108fa22d4 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -648,6 +648,7 @@ impl YieldOnce { /// coroutine from an async helper) to the engine for parallel execution. #[pyclass( frozen, + generic, name = "_Concurrently", module = "pants.engine.internals.native_engine" )] From 4dc0d526d8dc51ef02c4b8b4a6b7aaee94ceeda9 Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Fri, 17 Apr 2026 12:34:59 +0200 Subject: [PATCH 6/8] typing: Transparently return generic-Call It is not a coroutine --- .../pants/engine/internals/native_engine.pyi | 20 ++- .../pants/engine/internals/selectors.py | 157 +++++++++--------- src/python/pants/engine/rules.py | 10 +- src/rust/engine/src/externs/mod.rs | 8 +- 4 files changed, 101 insertions(+), 94 deletions(-) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index b37bcb674b8..1dbfccbc358 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -1158,7 +1158,7 @@ _Output = TypeVar("_Output") _Output_co = TypeVar("_Output_co", covariant=True) _Input = TypeVar("_Input") -class Call: +class Call(Generic[_Output_co]): rule_id: str output_type: type args: tuple[Any, ...] @@ -1194,39 +1194,41 @@ class Call: input_arg0: type[_Input] | _Input, input_arg1: _Input | None = None, ) -> None: ... - def __await__(self) -> Generator[Any, None, Any]: ... + def __await__(self) -> Generator[Any, None, _Output_co]: ... def __repr__(self) -> str: ... class _Concurrently(Generic[_Output_co]): - calls: tuple[Coroutine[Any, Any, Any] | Call | _Concurrently[Any], ...] + calls: tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...] def __init__( - self, calls: tuple[Coroutine[Any, Any, Any] | Call | _Concurrently[Any], ...] + self, calls: tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...] ) -> None: ... def __await__( self, ) -> Generator[ - tuple[Coroutine[Any, Any, Any] | Call | _Concurrently[Any], ...], None, _Output_co + tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...], None, _Output_co ]: ... -class RuleCallTrampoline: +class RuleCallTrampoline(Generic[_Output]): """The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so each invocation constructs the already-awaitable `Call` directly. `__getattribute__` forwards `__doc__` and other introspection attrs to the wrapped function. """ rule_id: str - output_type: type + output_type: type[_Output] rule: Any __wrapped__: Callable[..., Any] def __init__( self, rule_id: str, - output_type: type, + output_type: type[_Output], wrapped: Callable[..., Any], rule: Any, ) -> None: ... - def __call__(self, *args: Any, __implicitly: Sequence[Any] = (), **kwargs: Any) -> Call: ... + def __call__( + self, *args: Any, __implicitly: Sequence[Any] = (), **kwargs: Any + ) -> Call[_Output]: ... # ------------------------------------------------------------------------------ # (uncategorized) diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index e7d0c1291d5..2aa47a9f250 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -48,142 +48,143 @@ def __str__(self) -> str: @overload def Concurrently( - __gets: Iterable[Coroutine[Any, Any, _Output] | _Concurrently[_Output]], + __gets: Iterable[Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output]], ) -> _Concurrently[tuple[_Output, ...]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get1: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get2: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get3: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get4: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get5: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get6: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get7: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get8: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get9: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - __get10: Coroutine[Any, Any, _Output] | _Concurrently[_Output], - *__gets: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __get0: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get1: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get2: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get3: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get4: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get5: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get6: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get7: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get8: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get9: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get10: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + *__gets: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], ) -> _Concurrently[tuple[_Output, ...]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], - __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], - __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], - __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], - __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], - __get7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7], - __get8: Coroutine[Any, Any, _Out8] | _Concurrently[_Out8], - __get9: Coroutine[Any, Any, _Out9] | _Concurrently[_Out9], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7], + __get8: Coroutine[Any, Any, _Out8] | Call[_Out8] | _Concurrently[_Out8], + __get9: Coroutine[Any, Any, _Out9] | Call[_Out9] | _Concurrently[_Out9], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], - __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], - __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], - __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], - __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], - __get7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7], - __get8: Coroutine[Any, Any, _Out8] | _Concurrently[_Out8], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7], + __get8: Coroutine[Any, Any, _Out8] | Call[_Out8] | _Concurrently[_Out8], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], - __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], - __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], - __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], - __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], - __get7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], - __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], - __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], - __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], - __get6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], - __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], - __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], - __get5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], - __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], - __get4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], - __get3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], - __get2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], ) -> _Concurrently[tuple[_Out0, _Out1, _Out2]]: ... @overload def Concurrently( - __get0: Coroutine[Any, Any, _Out0] | _Concurrently[_Out0], - __get1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1], + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], ) -> _Concurrently[tuple[_Out0, _Out1]]: ... def Concurrently( __arg0: ( - Iterable[Coroutine[Any, Any, _Output] | _Concurrently[_Output]] + Iterable[Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output]] | Coroutine[Any, Any, _Out0] + | Call[_Out0] | _Concurrently[_Out0] ), - __arg1: Coroutine[Any, Any, _Out1] | _Concurrently[_Out1] | None = None, - __arg2: Coroutine[Any, Any, _Out2] | _Concurrently[_Out2] | None = None, - __arg3: Coroutine[Any, Any, _Out3] | _Concurrently[_Out3] | None = None, - __arg4: Coroutine[Any, Any, _Out4] | _Concurrently[_Out4] | None = None, - __arg5: Coroutine[Any, Any, _Out5] | _Concurrently[_Out5] | None = None, - __arg6: Coroutine[Any, Any, _Out6] | _Concurrently[_Out6] | None = None, - __arg7: Coroutine[Any, Any, _Out7] | _Concurrently[_Out7] | None = None, - __arg8: Coroutine[Any, Any, _Out8] | _Concurrently[_Out8] | None = None, - __arg9: Coroutine[Any, Any, _Out9] | _Concurrently[_Out9] | None = None, - *__args: Coroutine[Any, Any, _Output] | _Concurrently[_Output], + __arg1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1] | None = None, + __arg2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2] | None = None, + __arg3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3] | None = None, + __arg4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4] | None = None, + __arg5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5] | None = None, + __arg6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6] | None = None, + __arg7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7] | None = None, + __arg8: Coroutine[Any, Any, _Out8] | Call[_Out8] | _Concurrently[_Out8] | None = None, + __arg9: Coroutine[Any, Any, _Out9] | Call[_Out9] | _Concurrently[_Out9] | None = None, + *__args: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], ) -> _Concurrently[ tuple[_Output, ...] | tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9] diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index ea0d5e30f49..fd03418d6d7 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -24,7 +24,7 @@ from typing_extensions import ParamSpec from pants.engine.engine_aware import SideEffecting -from pants.engine.internals.native_engine import RuleCallTrampoline +from pants.engine.internals.native_engine import Call, RuleCallTrampoline from pants.engine.internals.rule_visitor import collect_awaitables from pants.engine.internals.selectors import AwaitableConstraints from pants.engine.internals.selectors import concurrently as concurrently # noqa: F401 @@ -54,7 +54,7 @@ class RuleType(Enum): R = TypeVar("R") SyncRuleT = Callable[P, R] AsyncRuleT = Callable[P, Coroutine[Any, Any, R]] -RuleDecorator = Callable[[SyncRuleT | AsyncRuleT], AsyncRuleT] +RuleDecorator = Callable[[SyncRuleT | AsyncRuleT], Callable[P, Call[R]]] def _make_rule( @@ -119,7 +119,7 @@ def wrapper(original_func): polymorphic=polymorphic, ) return cast( - Callable[P, R], + Callable[P, Call[R]], RuleCallTrampoline(canonical_name, return_type, original_func, task_rule), ) @@ -253,7 +253,7 @@ class _RuleDecoratorKwargs(RuleDecoratorKwargs): def rule_decorator( func: SyncRuleT | AsyncRuleT, **kwargs: Unpack[_RuleDecoratorKwargs] -) -> AsyncRuleT: +) -> Callable[P, Call[R]]: if not inspect.isfunction(func): raise ValueError("The @rule decorator expects to be placed on a function.") @@ -388,7 +388,7 @@ def validate_requirements( ) -def inner_rule(*args, **kwargs) -> AsyncRuleT | RuleDecorator: +def inner_rule(*args, **kwargs) -> Callable[P, Call[R]] | RuleDecorator: if len(args) == 1 and inspect.isfunction(args[0]): return rule_decorator(*args, **kwargs) else: diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 31108fa22d4..256f6e41b31 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -488,7 +488,11 @@ impl PyGeneratorResponseNativeCall { } } -#[pyclass(name = "Call", module = "pants.engine.internals.native_engine")] +#[pyclass( + generic, + name = "Call", + module = "pants.engine.internals.native_engine" +)] pub struct PyGeneratorResponseCall(RwLock>); impl PyGeneratorResponseCall { @@ -686,7 +690,7 @@ impl PyGeneratorResponseCall { /// The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so /// each invocation constructs the already-awaitable `Call` directly. /// `__getattribute__` forwards `__doc__` and other introspection attrs to the wrapped function. -#[pyclass(frozen, module = "pants.engine.internals.native_engine")] +#[pyclass(frozen, generic, module = "pants.engine.internals.native_engine")] pub struct RuleCallTrampoline { rule_id: PyBackedStr, #[pyo3(get)] From 6a8601ce129cae1e013b39aa500ef6389cfa10c3 Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Fri, 17 Apr 2026 12:58:00 +0200 Subject: [PATCH 7/8] refactor: Concretely recurse into concurrently Previous method of "re-generatoring" was a hack --- src/rust/engine/src/externs/mod.rs | 53 ++++++++++++++++-------------- src/rust/engine/src/nodes/task.rs | 50 +++++++++++++++++++++------- 2 files changed, 67 insertions(+), 36 deletions(-) diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 256f6e41b31..66f3b2a593d 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -357,32 +357,10 @@ pub(crate) fn generator_send( } else if let Ok(call) = response.extract::>() { Ok(GeneratorResponse::NativeCall(call.take(py)?)) } else if let Ok(get_multi) = response.cast::() { - // Was an `All` or `concurrently`. Each item is one of: - // * a generator (async helper) — drive via generator_send; - // * a `Call` — dispatch directly via `gen_call`; - // * a `_Concurrently` — nested `concurrently(...)`. Treat its awaiter as a - // generator so the outer engine loop recurses into the inner tuple. let items = get_multi .try_iter()? - .map(|item| { - let item = item?; - if item.is_instance(&generator_type.as_py_type(py))? { - Ok(AllItem::Generator(Value::new(item.unbind()))) - } else if let Ok(call) = item.extract::>() { - call.take(py) - .map(AllItem::Call) - .map_err(PyValueError::new_err) - } else if let Ok(concurrently) = item.extract::>() { - let awaiter = Py::new(py, concurrently.awaiter(py))?; - Ok(AllItem::Generator(Value::new(awaiter.into_any()))) - } else { - Err(PyValueError::new_err(format!( - "Expected an `All` or `concurrently` to receive calls to rules, \ - but got: {response}" - ))) - } - }) - .collect::, _>>()?; + .map(|item| AllItem::parse(py, generator_type, item?)) + .collect::>>()?; Ok(GeneratorResponse::All(items)) } else { Err(PyValueError::new_err(format!( @@ -836,4 +814,31 @@ pub enum AllItem { Generator(Value), /// A direct `Call` returned by a call-by-name `@rule` invocation; execute it as-is. Call(Call), + /// A nested `concurrently(..)`: recursively join its items and wrap the results in a tuple. + Concurrent(Vec), +} + +impl AllItem { + /// Parse one item from an `All`/`concurrently(..)` tuple. Nested `_Concurrently` is + /// resolved recursively. + fn parse(py: Python<'_>, generator_type: &TypeId, item: Bound<'_, PyAny>) -> PyResult { + if item.is_instance(&generator_type.as_py_type(py))? { + Ok(Self::Generator(Value::new(item.unbind()))) + } else if let Ok(call) = item.extract::>() { + call.take(py).map(Self::Call).map_err(PyValueError::new_err) + } else if let Ok(concurrently) = item.extract::>() { + let items = concurrently + .calls + .bind(py) + .iter() + .map(|item| Self::parse(py, generator_type, item)) + .collect::>>()?; + Ok(Self::Concurrent(items)) + } else { + Err(PyValueError::new_err(format!( + "Expected an `All` or `concurrently` item to be a rule call, coroutine, or \ + nested `concurrently(...)`, but got: {item}" + ))) + } + } } diff --git a/src/rust/engine/src/nodes/task.rs b/src/rust/engine/src/nodes/task.rs index 00994a399a7..117e0f3d38c 100644 --- a/src/rust/engine/src/nodes/task.rs +++ b/src/rust/engine/src/nodes/task.rs @@ -121,6 +121,43 @@ impl Task { .boxed() } + // Dispatches a single `AllItem` to its corresponding driver, recursing into nested + // `concurrently(..)` groups and returning their joined results as a tuple `Value`. + fn gen_all_item( + context: &Context, + params: Params, + entry: Intern>, + item: externs::AllItem, + ) -> BoxFuture<'_, NodeResult> { + async move { + match item { + externs::AllItem::Generator(generator) => { + Self::gen_generator(context, params, entry, generator).await + } + externs::AllItem::Call(call) => Self::gen_call(context, params, entry, call).await, + externs::AllItem::Concurrent(items) => { + let values = Self::gen_all(context, params, entry, items).await?; + Python::attach(|py| externs::store_tuple(py, values)) + .map_err(|err| Python::attach(|py| Failure::from_py_err_with_gil(py, err))) + } + } + } + .boxed() + } + + // Concurrently drives every item in an `All`/`concurrently(..)`. + async fn gen_all( + context: &Context, + params: Params, + entry: Intern>, + items: Vec, + ) -> NodeResult> { + let futures = items + .into_iter() + .map(|item| Self::gen_all_item(context, params.clone(), entry, item)); + future::try_join_all(futures).await + } + /// /// Given a python generator Value, loop to request the generator's dependencies until /// it completes with a result Value or fails with an error. @@ -166,18 +203,7 @@ impl Task { } GeneratorResponse::All(items) => { let _blocking_token = workunit.blocking(); - let get_futures = items - .into_iter() - .map(|item| match item { - externs::AllItem::Generator(generator) => { - Self::gen_generator(context, params.clone(), entry, generator) - } - externs::AllItem::Call(call) => { - Self::gen_call(context, params.clone(), entry, call).boxed() - } - }) - .collect::>(); - match future::try_join_all(get_futures).await { + match Self::gen_all(context, params.clone(), entry, items).await { Ok(values) => { let values_tuple_result = Python::attach(|py| externs::store_tuple(py, values)); From 735751e21de565ff572bbf56cc19438ee7a765d3 Mon Sep 17 00:00:00 2001 From: Tobias Nilsson Date: Sat, 18 Apr 2026 13:48:12 +0200 Subject: [PATCH 8/8] feat: Remove PySequence roundtrip Makes _Concurrently and Call have the same shape --- .../pants/engine/internals/native_engine.pyi | 9 +- src/python/pants/testutil/rule_runner.py | 10 +- src/rust/engine/src/externs/mod.rs | 111 ++++++++++++------ src/rust/engine/src/nodes/task.rs | 4 +- 4 files changed, 90 insertions(+), 44 deletions(-) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index 1dbfccbc358..58081607279 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -1198,15 +1198,14 @@ class Call(Generic[_Output_co]): def __repr__(self) -> str: ... class _Concurrently(Generic[_Output_co]): - calls: tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...] def __init__( self, calls: tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...] ) -> None: ... - def __await__( + def __await__(self) -> Generator[_Concurrently[_Output_co], None, _Output_co]: ... + @property + def calls( self, - ) -> Generator[ - tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...], None, _Output_co - ]: ... + ) -> tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...]: ... class RuleCallTrampoline(Generic[_Output]): """The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so diff --git a/src/python/pants/testutil/rule_runner.py b/src/python/pants/testutil/rule_runner.py index 3b2bf18c611..b82945cb482 100644 --- a/src/python/pants/testutil/rule_runner.py +++ b/src/python/pants/testutil/rule_runner.py @@ -35,7 +35,11 @@ from pants.engine.fs import CreateDigest, Digest, FileContent, Snapshot, Workspace from pants.engine.goal import CurrentExecutingGoals, Goal from pants.engine.internals import native_engine, options_parsing -from pants.engine.internals.native_engine import ProcessExecutionEnvironment, PyExecutor +from pants.engine.internals.native_engine import ( + ProcessExecutionEnvironment, + PyExecutor, + _Concurrently, +) from pants.engine.internals.scheduler import ExecutionError, Scheduler, SchedulerSession from pants.engine.internals.selectors import Call, Params from pants.engine.internals.session import SessionValues @@ -710,7 +714,7 @@ def run_rule_with_mocks( unconsumed_mock_calls = set(mock_calls.keys()) - def get(res: Call): + def get(res: Any): if not isinstance(res, Call): raise AssertionError(f"Bad arg type: {res}") mock_call = mock_calls.get(res.rule_id) @@ -740,6 +744,8 @@ def warn_on_unconsumed_mocks(): res = rule_coroutine.send(rule_input) if isinstance(res, Call): rule_input = get(res) + elif isinstance(res, _Concurrently): + rule_input = [get(g) for g in res.calls] elif type(res) in (tuple, list): rule_input = [get(g) for g in res] else: diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 66f3b2a593d..14a97e45d8c 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -12,7 +12,7 @@ use pyo3::exceptions::{PyException, PyStopIteration, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::sync::{MutexExt, RwLockExt}; -use pyo3::types::{PyBool, PyBytes, PyDict, PySequence, PyString, PyTuple, PyType}; +use pyo3::types::{PyBool, PyBytes, PyDict, PyString, PyTuple, PyType}; use pyo3::{create_exception, import_exception, intern}; use smallvec::{SmallVec, smallvec}; use std::collections::BTreeMap; @@ -288,13 +288,12 @@ pub(crate) enum GeneratorInput { /// - coroutines may await: /// - `Call` /// - other coroutines, -/// - sequences of those types. +/// - a `_Concurrently` batch of the above. /// - we will `send` back a single value or tupled values to the coroutine, or `throw` an exception. /// - a coroutine will eventually return a single return value. /// pub(crate) fn generator_send( py: Python<'_>, - generator_type: &TypeId, generator: &Value, input: GeneratorInput, ) -> Result { @@ -356,12 +355,8 @@ pub(crate) fn generator_send( Ok(GeneratorResponse::Call(call.take(py)?)) } else if let Ok(call) = response.extract::>() { Ok(GeneratorResponse::NativeCall(call.take(py)?)) - } else if let Ok(get_multi) = response.cast::() { - let items = get_multi - .try_iter()? - .map(|item| AllItem::parse(py, generator_type, item?)) - .collect::>>()?; - Ok(GeneratorResponse::All(items)) + } else if let Ok(concurrently) = response.extract::>() { + Ok(GeneratorResponse::All(concurrently.take_items(py)?)) } else { Err(PyValueError::new_err(format!( "Async @rule error. Expected a rule call, but got: {response}" @@ -626,8 +621,9 @@ impl YieldOnce { } } -/// Yielded by `concurrently(...)` to hand a tuple of awaitables (each either a `Call` or a -/// coroutine from an async helper) to the engine for parallel execution. +/// Yielded by `concurrently(...)` to hand a batch of awaitables (each either a `Call`, a +/// coroutine from an async helper, or a nested `concurrently(..)`) to the engine for +/// parallel execution. #[pyclass( frozen, generic, @@ -635,24 +631,49 @@ impl YieldOnce { module = "pants.engine.internals.native_engine" )] pub struct PyConcurrently { - calls: Py, + items: Mutex>>, } impl PyConcurrently { - pub(crate) fn awaiter(&self, py: Python<'_>) -> YieldOnce { - YieldOnce::new(self.calls.clone_ref(py)) + pub(crate) fn take_items(&self, py: Python<'_>) -> PyResult> { + self.items + .lock_py_attached(py) + .take() + .ok_or_else(|| PyValueError::new_err("A `concurrently(...)` may only be awaited once.")) } } #[pymethods] impl PyConcurrently { #[new] - fn __new__(calls: Py) -> Self { - Self { calls } + fn __new__(py: Python<'_>, calls: Bound<'_, PyTuple>) -> PyResult { + let items = calls + .iter() + .map(|item| AllItem::parse(py, item)) + .collect::>>()?; + Ok(Self { + items: Mutex::new(Some(items)), + }) } - fn __await__(&self, py: Python<'_>) -> YieldOnce { - self.awaiter(py) + fn __await__(slf: Py) -> YieldOnce { + YieldOnce::new(slf) + } + + /// Rebuilds the original awaitables as fresh Python objects so test harnesses (e.g. + /// `run_rule_with_mocks`) can iterate and introspect without consuming the engine-side + /// items. Not used on the hot path. + #[getter] + fn calls<'py>(&self, py: Python<'py>) -> PyResult> { + let guard = self.items.lock_py_attached(py); + let items = guard.as_ref().ok_or_else(|| { + PyValueError::new_err("A `concurrently(...)` has already been awaited.") + })?; + let py_items = items + .iter() + .map(|item| item.to_python(py)) + .collect::>>()?; + PyTuple::new(py, py_items) } } @@ -763,7 +784,7 @@ pub struct NativeCall { pub call: BoxFuture<'static, Result>, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Call { pub rule_id: RuleId, pub output_type: TypeId, @@ -809,6 +830,7 @@ pub enum GeneratorResponse { All(Vec), } +#[derive(Clone)] pub enum AllItem { /// A generator produced by an `async def` rule helper; drive it via generator_send. Generator(Value), @@ -819,26 +841,47 @@ pub enum AllItem { } impl AllItem { - /// Parse one item from an `All`/`concurrently(..)` tuple. Nested `_Concurrently` is - /// resolved recursively. - fn parse(py: Python<'_>, generator_type: &TypeId, item: Bound<'_, PyAny>) -> PyResult { - if item.is_instance(&generator_type.as_py_type(py))? { - Ok(Self::Generator(Value::new(item.unbind()))) - } else if let Ok(call) = item.extract::>() { + fn parse(py: Python<'_>, item: Bound<'_, PyAny>) -> PyResult { + if let Ok(call) = item.extract::>() { call.take(py).map(Self::Call).map_err(PyValueError::new_err) } else if let Ok(concurrently) = item.extract::>() { - let items = concurrently - .calls - .bind(py) - .iter() - .map(|item| Self::parse(py, generator_type, item)) - .collect::>>()?; - Ok(Self::Concurrent(items)) + Ok(Self::Concurrent(concurrently.take_items(py)?)) + } else if item.is_instance(&COROUTINE_TYPE.as_py_type(py))? { + Ok(Self::Generator(Value::new(item.unbind()))) } else { Err(PyValueError::new_err(format!( - "Expected an `All` or `concurrently` item to be a rule call, coroutine, or \ - nested `concurrently(...)`, but got: {item}" + "Expected a `concurrently(..)` argument to be a rule call, coroutine, or \ + nested `concurrently(..)`, but got: {item}" ))) } } + + fn to_python<'py>(&self, py: Python<'py>) -> PyResult> { + match self { + Self::Generator(value) => Ok(value.bind(py).clone()), + Self::Call(call) => { + let py_call = PyGeneratorResponseCall(RwLock::new(Some(call.clone()))); + Ok(Py::new(py, py_call)?.into_bound(py).into_any()) + } + Self::Concurrent(items) => { + let py_concurrently = PyConcurrently { + items: Mutex::new(Some(items.clone())), + }; + Ok(Py::new(py, py_concurrently)?.into_bound(py).into_any()) + } + } + } } + +static COROUTINE_TYPE: LazyLock = LazyLock::new(|| { + Python::attach(|py| { + let coroutine_type = py + .import("types") + .expect("Failed to import `types`") + .getattr("CoroutineType") + .expect("`types.CoroutineType` is missing") + .cast_into::() + .expect("`types.CoroutineType` is not a type"); + TypeId::new(&coroutine_type) + }) +}); diff --git a/src/rust/engine/src/nodes/task.rs b/src/rust/engine/src/nodes/task.rs index 117e0f3d38c..081d8b9fa89 100644 --- a/src/rust/engine/src/nodes/task.rs +++ b/src/rust/engine/src/nodes/task.rs @@ -171,9 +171,7 @@ impl Task { ) -> NodeResult<(Value, TypeId)> { let mut input = GeneratorInput::Initial; loop { - let response = Python::attach(|py| { - externs::generator_send(py, &context.core.types.coroutine, &generator, input) - })?; + let response = Python::attach(|py| externs::generator_send(py, &generator, input))?; match response { GeneratorResponse::NativeCall(call) => { let _blocking_token = workunit.blocking();