diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index f9623e05826..58081607279 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -6,7 +6,7 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Coroutine, Generator, Iterable, Iterator, Mapping, Sequence from datetime import datetime from enum import Enum from io import RawIOBase @@ -1155,12 +1155,14 @@ def hash_prefix_zero_bits(item: str) -> int: ... # ------------------------------------------------------------------------------ _Output = TypeVar("_Output") +_Output_co = TypeVar("_Output_co", covariant=True) _Input = TypeVar("_Input") -class PyGeneratorResponseCall: +class Call(Generic[_Output_co]): rule_id: str output_type: type - inputs: Sequence[Any] + args: tuple[Any, ...] + implicit_args: dict[Any, type] @overload def __init__( @@ -1192,6 +1194,40 @@ class PyGeneratorResponseCall: input_arg0: type[_Input] | _Input, input_arg1: _Input | None = None, ) -> None: ... + def __await__(self) -> Generator[Any, None, _Output_co]: ... + def __repr__(self) -> str: ... + +class _Concurrently(Generic[_Output_co]): + def __init__( + self, calls: tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...] + ) -> None: ... + def __await__(self) -> Generator[_Concurrently[_Output_co], None, _Output_co]: ... + @property + def calls( + self, + ) -> tuple[Coroutine[Any, Any, Any] | Call[Any] | _Concurrently[Any], ...]: ... + +class RuleCallTrampoline(Generic[_Output]): + """The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so + each invocation constructs the already-awaitable `Call` directly. + `__getattribute__` forwards `__doc__` and other introspection attrs to the wrapped function. + """ + + rule_id: str + output_type: type[_Output] + rule: Any + __wrapped__: Callable[..., Any] + + def __init__( + self, + rule_id: str, + output_type: type[_Output], + wrapped: Callable[..., Any], + rule: Any, + ) -> None: ... + def __call__( + self, *args: Any, __implicitly: Sequence[Any] = (), **kwargs: Any + ) -> Call[_Output]: ... # ------------------------------------------------------------------------------ # (uncategorized) diff --git a/src/python/pants/engine/internals/rule_visitor.py b/src/python/pants/engine/internals/rule_visitor.py index 2d00f198019..346061afe12 100644 --- a/src/python/pants/engine/internals/rule_visitor.py +++ b/src/python/pants/engine/internals/rule_visitor.py @@ -16,6 +16,7 @@ import typing_extensions from pants.base.exceptions import RuleTypeError +from pants.engine.internals.native_engine import RuleCallTrampoline from pants.engine.internals.selectors import ( AwaitableConstraints, concurrently, @@ -177,6 +178,11 @@ def _lookup_return_type(func: Callable, check: bool = False) -> Any: class _AwaitableCollector(ast.NodeVisitor): def __init__(self, func: Callable): + # `func` may be a RuleCallTrampoline (the return value of an `@rule`-decorated + # function). `inspect.getsource` and friends only know about real Python functions, + # so follow `__wrapped__` to reach the underlying implementation. + if isinstance(func, RuleCallTrampoline): + func = func.__wrapped__ self.func = func source = inspect.getsource(func) or "" beginning_indent = _get_starting_indent(source) @@ -314,9 +320,9 @@ def _get_byname_awaitable( def visit_Call(self, call_node: ast.Call) -> None: func = self._lookup(call_node.func) if func is not None: - if (inspect.isfunction(func) or isinstance(func, RuleDescriptor)) and ( - rule_id := getattr(func, "rule_id", None) - ) is not None: + if ( + inspect.isfunction(func) or isinstance(func, (RuleDescriptor, RuleCallTrampoline)) + ) and (rule_id := getattr(func, "rule_id", None)) is not None: # Is a direct `@rule` call. self.awaitables.append(self._get_byname_awaitable(rule_id, func, call_node)) elif inspect.iscoroutinefunction(func): diff --git a/src/python/pants/engine/internals/scheduler_test.py b/src/python/pants/engine/internals/scheduler_test.py index 6f66f31e5d3..2e3f4711584 100644 --- a/src/python/pants/engine/internals/scheduler_test.py +++ b/src/python/pants/engine/internals/scheduler_test.py @@ -474,12 +474,6 @@ def test_trace_includes_nested_exception_traceback() -> None: File LOCATION-INFO, in catch_and_reraise return await raise_an_exception(SomeInput(outer_input.s)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File LOCATION-INFO, in wrapper - return await call - ^^^^^^^^^^ - File LOCATION-INFO, in __await__ - result = yield self - ^^^^^^^^^^ File LOCATION-INFO, in raise_an_exception raise Exception(some_input.s) Exception: asdf diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index 236adcbb3f4..2aa47a9f250 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -4,11 +4,12 @@ from __future__ import annotations import itertools -from collections.abc import Coroutine, Generator, Iterable +from collections.abc import Coroutine, Iterable from dataclasses import dataclass -from typing import Any, TypeVar, cast, overload +from typing import Any, TypeVar, overload -from pants.engine.internals.native_engine import PyGeneratorResponseCall +from pants.engine.internals.native_engine import Call as Call # noqa: F401 +from pants.engine.internals.native_engine import _Concurrently from pants.util.strutil import softwrap _Output = TypeVar("_Output") @@ -30,26 +31,6 @@ def __str__(self) -> str: return repr(self) -class Call(PyGeneratorResponseCall): - def __await__( - self, - ) -> Generator[Any, None, Any]: - result = yield self - return result - - def __repr__(self) -> str: - return f"Call({self.rule_id}(...) -> {self.output_type.__name__})" - - -@dataclass(frozen=True) -class _Concurrently: - calls: tuple[Coroutine, ...] - - def __await__(self) -> Generator[tuple[Coroutine, ...], None, tuple]: - result = yield self.calls - return cast(tuple, result) - - # These type variables are used to parametrize from 1 to 10 args when used in a tuple-style # concurrently() call. @@ -66,140 +47,145 @@ def __await__(self) -> Generator[tuple[Coroutine, ...], None, tuple]: @overload -async def Concurrently( - __gets: Iterable[Coroutine[Any, Any, _Output]], -) -> tuple[_Output, ...]: ... +def Concurrently( + __gets: Iterable[Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output]], +) -> _Concurrently[tuple[_Output, ...]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Output], - __get1: Coroutine[Any, Any, _Output], - __get2: Coroutine[Any, Any, _Output], - __get3: Coroutine[Any, Any, _Output], - __get4: Coroutine[Any, Any, _Output], - __get5: Coroutine[Any, Any, _Output], - __get6: Coroutine[Any, Any, _Output], - __get7: Coroutine[Any, Any, _Output], - __get8: Coroutine[Any, Any, _Output], - __get9: Coroutine[Any, Any, _Output], - __get10: Coroutine[Any, Any, _Output], - *__gets: Coroutine[Any, Any, _Output], -) -> tuple[_Output, ...]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get1: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get2: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get3: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get4: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get5: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get6: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get7: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get8: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get9: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + __get10: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], + *__gets: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], +) -> _Concurrently[tuple[_Output, ...]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], - __get7: Coroutine[Any, Any, _Out7], - __get8: Coroutine[Any, Any, _Out8], - __get9: Coroutine[Any, Any, _Out9], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7], + __get8: Coroutine[Any, Any, _Out8] | Call[_Out8] | _Concurrently[_Out8], + __get9: Coroutine[Any, Any, _Out9] | Call[_Out9] | _Concurrently[_Out9], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], - __get7: Coroutine[Any, Any, _Out7], - __get8: Coroutine[Any, Any, _Out8], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7], + __get8: Coroutine[Any, Any, _Out8] | Call[_Out8] | _Concurrently[_Out8], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], - __get7: Coroutine[Any, Any, _Out7], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], + __get7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], - __get6: Coroutine[Any, Any, _Out6], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], + __get6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], - __get5: Coroutine[Any, Any, _Out5], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], + __get5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], - __get4: Coroutine[Any, Any, _Out4], -) -> tuple[_Out0, _Out1, _Out2, _Out3, _Out4]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], + __get4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3, _Out4]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], - __get3: Coroutine[Any, Any, _Out3], -) -> tuple[_Out0, _Out1, _Out2, _Out3]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], + __get3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2, _Out3]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], - __get2: Coroutine[Any, Any, _Out2], -) -> tuple[_Out0, _Out1, _Out2]: ... +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], + __get2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2], +) -> _Concurrently[tuple[_Out0, _Out1, _Out2]]: ... @overload -async def Concurrently( - __get0: Coroutine[Any, Any, _Out0], - __get1: Coroutine[Any, Any, _Out1], -) -> tuple[_Out0, _Out1]: ... - - -async def Concurrently( - __arg0: (Iterable[Coroutine[Any, Any, _Output]] | Coroutine[Any, Any, _Out0]), - __arg1: Coroutine[Any, Any, _Out1] | None = None, - __arg2: Coroutine[Any, Any, _Out2] | None = None, - __arg3: Coroutine[Any, Any, _Out3] | None = None, - __arg4: Coroutine[Any, Any, _Out4] | None = None, - __arg5: Coroutine[Any, Any, _Out5] | None = None, - __arg6: Coroutine[Any, Any, _Out6] | None = None, - __arg7: Coroutine[Any, Any, _Out7] | None = None, - __arg8: Coroutine[Any, Any, _Out8] | None = None, - __arg9: Coroutine[Any, Any, _Out9] | None = None, - *__args: Coroutine[Any, Any, _Output], -) -> ( +def Concurrently( + __get0: Coroutine[Any, Any, _Out0] | Call[_Out0] | _Concurrently[_Out0], + __get1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1], +) -> _Concurrently[tuple[_Out0, _Out1]]: ... + + +def Concurrently( + __arg0: ( + Iterable[Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output]] + | Coroutine[Any, Any, _Out0] + | Call[_Out0] + | _Concurrently[_Out0] + ), + __arg1: Coroutine[Any, Any, _Out1] | Call[_Out1] | _Concurrently[_Out1] | None = None, + __arg2: Coroutine[Any, Any, _Out2] | Call[_Out2] | _Concurrently[_Out2] | None = None, + __arg3: Coroutine[Any, Any, _Out3] | Call[_Out3] | _Concurrently[_Out3] | None = None, + __arg4: Coroutine[Any, Any, _Out4] | Call[_Out4] | _Concurrently[_Out4] | None = None, + __arg5: Coroutine[Any, Any, _Out5] | Call[_Out5] | _Concurrently[_Out5] | None = None, + __arg6: Coroutine[Any, Any, _Out6] | Call[_Out6] | _Concurrently[_Out6] | None = None, + __arg7: Coroutine[Any, Any, _Out7] | Call[_Out7] | _Concurrently[_Out7] | None = None, + __arg8: Coroutine[Any, Any, _Out8] | Call[_Out8] | _Concurrently[_Out8] | None = None, + __arg9: Coroutine[Any, Any, _Out9] | Call[_Out9] | _Concurrently[_Out9] | None = None, + *__args: Coroutine[Any, Any, _Output] | Call[_Output] | _Concurrently[_Output], +) -> _Concurrently[ tuple[_Output, ...] | tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8, _Out9] | tuple[_Out0, _Out1, _Out2, _Out3, _Out4, _Out5, _Out6, _Out7, _Out8] @@ -211,7 +197,7 @@ async def Concurrently( | tuple[_Out0, _Out1, _Out2] | tuple[_Out0, _Out1] | tuple[_Out0] -): +]: """Yield a tuple of Coroutine instances all at once. The `yield`ed value `self.calls` is interpreted by the engine within @@ -235,10 +221,10 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently(tuple(__arg0)) + return _Concurrently(tuple(__arg0)) if ( - isinstance(__arg0, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) and __arg1 is None and __arg2 is None and __arg3 is None @@ -250,11 +236,11 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0,)) + return _Concurrently((__arg0,)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) and __arg2 is None and __arg3 is None and __arg4 is None @@ -265,12 +251,12 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1)) + return _Concurrently((__arg0, __arg1)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) and __arg3 is None and __arg4 is None and __arg5 is None @@ -280,13 +266,13 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2)) + return _Concurrently((__arg0, __arg1, __arg2)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) and __arg4 is None and __arg5 is None and __arg6 is None @@ -295,14 +281,14 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) and __arg5 is None and __arg6 is None and __arg7 is None @@ -310,84 +296,84 @@ async def Concurrently( and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) and __arg6 is None and __arg7 is None and __arg8 is None and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) and __arg7 is None and __arg8 is None and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) - and isinstance(__arg7, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) + and isinstance(__arg7, (Coroutine, Call, _Concurrently)) and __arg8 is None and __arg9 is None and not __args ): - return await _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6, __arg7)) + return _Concurrently((__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6, __arg7)) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) - and isinstance(__arg7, Coroutine) - and isinstance(__arg8, Coroutine) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) + and isinstance(__arg7, (Coroutine, Call, _Concurrently)) + and isinstance(__arg8, (Coroutine, Call, _Concurrently)) and __arg9 is None and not __args ): - return await _Concurrently( + return _Concurrently( (__arg0, __arg1, __arg2, __arg3, __arg4, __arg5, __arg6, __arg7, __arg8) ) if ( - isinstance(__arg0, Coroutine) - and isinstance(__arg1, Coroutine) - and isinstance(__arg2, Coroutine) - and isinstance(__arg3, Coroutine) - and isinstance(__arg4, Coroutine) - and isinstance(__arg5, Coroutine) - and isinstance(__arg6, Coroutine) - and isinstance(__arg7, Coroutine) - and isinstance(__arg8, Coroutine) - and isinstance(__arg9, Coroutine) - and all(isinstance(arg, Coroutine) for arg in __args) + isinstance(__arg0, (Coroutine, Call, _Concurrently)) + and isinstance(__arg1, (Coroutine, Call, _Concurrently)) + and isinstance(__arg2, (Coroutine, Call, _Concurrently)) + and isinstance(__arg3, (Coroutine, Call, _Concurrently)) + and isinstance(__arg4, (Coroutine, Call, _Concurrently)) + and isinstance(__arg5, (Coroutine, Call, _Concurrently)) + and isinstance(__arg6, (Coroutine, Call, _Concurrently)) + and isinstance(__arg7, (Coroutine, Call, _Concurrently)) + and isinstance(__arg8, (Coroutine, Call, _Concurrently)) + and isinstance(__arg9, (Coroutine, Call, _Concurrently)) + and all(isinstance(arg, (Coroutine, Call, _Concurrently)) for arg in __args) ): - return await _Concurrently( + return _Concurrently( ( __arg0, __arg1, diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index a17e757a860..fd03418d6d7 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -3,10 +3,9 @@ from __future__ import annotations -import functools import inspect import sys -from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence +from collections.abc import Callable, Coroutine, Iterable, Mapping from dataclasses import dataclass from enum import Enum from types import FrameType, ModuleType @@ -25,8 +24,9 @@ from typing_extensions import ParamSpec from pants.engine.engine_aware import SideEffecting +from pants.engine.internals.native_engine import Call, RuleCallTrampoline from pants.engine.internals.rule_visitor import collect_awaitables -from pants.engine.internals.selectors import AwaitableConstraints, Call +from pants.engine.internals.selectors import AwaitableConstraints from pants.engine.internals.selectors import concurrently as concurrently # noqa: F401 from pants.engine.unions import UnionRule from pants.util.frozendict import FrozenDict @@ -54,18 +54,7 @@ class RuleType(Enum): R = TypeVar("R") SyncRuleT = Callable[P, R] AsyncRuleT = Callable[P, Coroutine[Any, Any, R]] -RuleDecorator = Callable[[SyncRuleT | AsyncRuleT], AsyncRuleT] - - -def _rule_call_trampoline( - rule_id: str, output_type: type[Any], func: Callable[P, R] -) -> Callable[P, R]: - @functools.wraps(func) # type: ignore - async def wrapper(*args, __implicitly: Sequence[Any] = (), **kwargs): - call = Call(rule_id, output_type, args, *__implicitly) - return await call - - return cast(Callable[P, R], wrapper) +RuleDecorator = Callable[[SyncRuleT | AsyncRuleT], Callable[P, Call[R]]] def _make_rule( @@ -112,13 +101,12 @@ def wrapper(original_func): awaitables = FrozenOrderedSet(collect_awaitables(original_func)) validate_requirements(func_id, parameter_types, awaitables, cacheable) - func = _rule_call_trampoline(canonical_name, return_type, original_func) - # NB: The named definition of the rule ends up wrapped in a trampoline to handle memoization - # and implicit arguments for direct by-name calls. But the `TaskRule` takes a reference to - # the original unwrapped function, which avoids the need for a special protocol when the - # engine invokes a @rule under memoization. - func.rule = TaskRule( + # NB: The named definition of the rule ends up wrapped in a trampoline to handle + # implicit arguments for direct by-name calls. The `TaskRule` takes a reference to + # the original unwrapped function, which avoids the need for a special protocol when + # the engine invokes a @rule under memoization. + task_rule = TaskRule( return_type, FrozenDict(parameter_types), awaitables, @@ -130,7 +118,10 @@ def wrapper(original_func): cacheable=cacheable, polymorphic=polymorphic, ) - return func + return cast( + Callable[P, Call[R]], + RuleCallTrampoline(canonical_name, return_type, original_func, task_rule), + ) return wrapper @@ -262,7 +253,7 @@ class _RuleDecoratorKwargs(RuleDecoratorKwargs): def rule_decorator( func: SyncRuleT | AsyncRuleT, **kwargs: Unpack[_RuleDecoratorKwargs] -) -> AsyncRuleT: +) -> Callable[P, Call[R]]: if not inspect.isfunction(func): raise ValueError("The @rule decorator expects to be placed on a function.") @@ -397,7 +388,7 @@ def validate_requirements( ) -def inner_rule(*args, **kwargs) -> AsyncRuleT | RuleDecorator: +def inner_rule(*args, **kwargs) -> Callable[P, Call[R]] | RuleDecorator: if len(args) == 1 and inspect.isfunction(args[0]): return rule_decorator(*args, **kwargs) else: diff --git a/src/python/pants/testutil/rule_runner.py b/src/python/pants/testutil/rule_runner.py index af776ca8055..b82945cb482 100644 --- a/src/python/pants/testutil/rule_runner.py +++ b/src/python/pants/testutil/rule_runner.py @@ -35,7 +35,11 @@ from pants.engine.fs import CreateDigest, Digest, FileContent, Snapshot, Workspace from pants.engine.goal import CurrentExecutingGoals, Goal from pants.engine.internals import native_engine, options_parsing -from pants.engine.internals.native_engine import ProcessExecutionEnvironment, PyExecutor +from pants.engine.internals.native_engine import ( + ProcessExecutionEnvironment, + PyExecutor, + _Concurrently, +) from pants.engine.internals.scheduler import ExecutionError, Scheduler, SchedulerSession from pants.engine.internals.selectors import Call, Params from pants.engine.internals.session import SessionValues @@ -710,34 +714,20 @@ def run_rule_with_mocks( unconsumed_mock_calls = set(mock_calls.keys()) - def get(res: Call | Coroutine): - if isinstance(res, Coroutine): - # A call-by-name element in a concurrently() is a Coroutine whose frame is - # the trampoline wrapper that creates and immediately awaits the Call. - locals = inspect.getcoroutinelocals(res) - assert locals is not None - rule_id = locals["rule_id"] - args = locals["args"] - kwargs = dict(locals["kwargs"]) - __implicitly = locals.get("__implicitly") - if __implicitly: - kwargs["__implicitly"] = __implicitly - mock_call = mock_calls.get(rule_id) - if mock_call: - unconsumed_mock_calls.discard(rule_id) - # Close the original, unmocked, coroutine, to prevent the "was never awaited" - # warning polluting stderr data that the test may examine. - res.close() - return mock_call(*args, **kwargs) - raise AssertionError(f"No mock_call provided for {rule_id}.") - elif isinstance(res, Call): - mock_call = mock_calls.get(res.rule_id) - if mock_call: - unconsumed_mock_calls.discard(res.rule_id) - return mock_call(*res.inputs) - raise AssertionError(f"No mock_call provided for {res.rule_id}.") - else: + def get(res: Any): + if not isinstance(res, Call): raise AssertionError(f"Bad arg type: {res}") + mock_call = mock_calls.get(res.rule_id) + if mock_call is None: + raise AssertionError(f"No mock_call provided for {res.rule_id}.") + unconsumed_mock_calls.discard(res.rule_id) + # NB: if the mock declares an `__implicitly` parameter, forward the raw `(dict,)` so it + # can inspect declared types (e.g. to route polymorphic dispatch); otherwise unpack the + # implicit values positionally. + implicit = res.implicit_args + if implicit and "__implicitly" in inspect.signature(mock_call).parameters: + return mock_call(*res.args, __implicitly=(implicit,)) + return mock_call(*res.args, *implicit) rule_coroutine = res rule_input = None @@ -754,6 +744,8 @@ def warn_on_unconsumed_mocks(): res = rule_coroutine.send(rule_input) if isinstance(res, Call): rule_input = get(res) + elif isinstance(res, _Concurrently): + rule_input = [get(g) for g in res.calls] elif type(res) in (tuple, list): rule_input = [get(g) for g in res] else: diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 51e4e352685..14a97e45d8c 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -10,8 +10,9 @@ use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; use pyo3::FromPyObject; use pyo3::exceptions::{PyException, PyStopIteration, PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; use pyo3::sync::{MutexExt, RwLockExt}; -use pyo3::types::{PyBool, PyBytes, PyDict, PySequence, PyString, PyTuple, PyType}; +use pyo3::types::{PyBool, PyBytes, PyDict, PyString, PyTuple, PyType}; use pyo3::{create_exception, import_exception, intern}; use smallvec::{SmallVec, smallvec}; use std::collections::BTreeMap; @@ -50,6 +51,8 @@ pub fn register(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add("EngineError", py.get_type::())?; m.add("IntrinsicError", py.get_type::())?; @@ -285,13 +288,12 @@ pub(crate) enum GeneratorInput { /// - coroutines may await: /// - `Call` /// - other coroutines, -/// - sequences of those types. +/// - a `_Concurrently` batch of the above. /// - we will `send` back a single value or tupled values to the coroutine, or `throw` an exception. /// - a coroutine will eventually return a single return value. /// pub(crate) fn generator_send( py: Python<'_>, - generator_type: &TypeId, generator: &Value, input: GeneratorInput, ) -> Result { @@ -353,25 +355,8 @@ pub(crate) fn generator_send( Ok(GeneratorResponse::Call(call.take(py)?)) } else if let Ok(call) = response.extract::>() { Ok(GeneratorResponse::NativeCall(call.take(py)?)) - } else if let Ok(get_multi) = response.cast::() { - // Was an `All` or `concurrently`. - let generators = get_multi - .try_iter()? - .map(|generator| { - let generator = generator?; - // TODO: Find a better way to check whether something is a coroutine... this seems - // unnecessarily awkward. - if generator.is_instance(&generator_type.as_py_type(py))? { - Ok(Value::new(generator.unbind())) - } else { - Err(PyValueError::new_err(format!( - "Expected an `All` or `concurrently` to receive calls to rules, \ - but got: {response}" - ))) - } - }) - .collect::, _>>()?; - Ok(GeneratorResponse::All(generators)) + } else if let Ok(concurrently) = response.extract::>() { + Ok(GeneratorResponse::All(concurrently.take_items(py)?)) } else { Err(PyValueError::new_err(format!( "Async @rule error. Expected a rule call, but got: {response}" @@ -476,7 +461,11 @@ impl PyGeneratorResponseNativeCall { } } -#[pyclass(subclass)] +#[pyclass( + generic, + name = "Call", + module = "pants.engine.internals.native_engine" +)] pub struct PyGeneratorResponseCall(RwLock>); impl PyGeneratorResponseCall { @@ -491,15 +480,10 @@ impl PyGeneratorResponseCall { )) } } -} -#[pymethods] -impl PyGeneratorResponseCall { - #[new] - #[pyo3(signature = (rule_id, output_type, args, input_arg0=None))] - fn __new__( - py: Python, - rule_id: String, + pub(crate) fn construct( + py: Python<'_>, + rule_id: &str, output_type: &Bound<'_, PyType>, args: &Bound<'_, PyTuple>, input_arg0: Option>, @@ -521,7 +505,7 @@ impl PyGeneratorResponseCall { let (input_types, inputs) = interpret_implicit_args(py, input_arg0)?; Ok(Self(RwLock::new(Some(Call { - rule_id: RuleId::from_string(rule_id), + rule_id: RuleId::new(rule_id), output_type, args, args_arity, @@ -529,6 +513,21 @@ impl PyGeneratorResponseCall { inputs, })))) } +} + +#[pymethods] +impl PyGeneratorResponseCall { + #[new] + #[pyo3(signature = (rule_id, output_type, args, input_arg0=None))] + fn __new__( + py: Python, + rule_id: PyBackedStr, + output_type: &Bound<'_, PyType>, + args: &Bound<'_, PyTuple>, + input_arg0: Option>, + ) -> PyResult { + Self::construct(py, &rule_id, output_type, args, input_arg0) + } #[getter] fn rule_id(&self, py: Python) -> PyResult { @@ -544,16 +543,137 @@ impl PyGeneratorResponseCall { } #[getter] - fn inputs(&self, py: Python<'_>) -> PyResult>> { + fn args<'py>(&self, py: Python<'py>) -> PyResult> { let inner = self.borrow_inner(py)?; - let args: Vec> = inner.args.as_ref().map_or_else( - || Ok(Vec::default()), - |args| args.to_py_object().extract(py), - )?; - Ok(args - .into_iter() - .chain(inner.inputs.iter().map(Key::to_py_object)) - .collect()) + match &inner.args { + Some(args) => Ok(args.to_py_object().extract(py)?), + None => Ok(PyTuple::empty(py)), + } + } + + /// NB: keyed on the value, mirroring the `{value: declared_type}` shape of `implicitly(...)` + /// at the call site so test mocks can read declared types back off the dict. + #[getter] + fn implicit_args<'py>(&self, py: Python<'py>) -> PyResult> { + let inner = self.borrow_inner(py)?; + let d = PyDict::new(py); + for (typ, val) in inner.input_types.iter().zip(inner.inputs.iter()) { + d.set_item(val.to_py_object(), typ.as_py_type(py))?; + } + Ok(d) + } + + fn __await__(slf: Py) -> YieldOnce { + YieldOnce::new(slf) + } + + fn __repr__(&self, py: Python) -> PyResult { + let inner = self.borrow_inner(py)?; + let output_type_name = inner.output_type.as_py_type(py).name()?.to_string(); + Ok(format!( + "Call({}(...) -> {output_type_name})", + inner.rule_id + )) + } +} + +/// A generator-protocol iterator that yields one value to the engine on the first `send`, +/// then raises `StopIteration(result)` when the engine sends the result back. Used to +/// implement `Call.__await__` and `_Concurrently.__await__`. +#[pyclass(frozen, module = "pants.engine.internals.native_engine")] +pub struct YieldOnce(Mutex>>); + +impl YieldOnce { + fn new(value: Py) -> Self { + Self(Mutex::new(Some(value.into_any()))) + } +} + +#[pymethods] +impl YieldOnce { + fn __iter__(slf: Py) -> Py { + slf + } + + fn __next__(&self, py: Python<'_>) -> PyResult> { + self.send(py, None) + } + + #[pyo3(signature = (value=None))] + fn send(&self, py: Python<'_>, value: Option>) -> PyResult> { + let mut state = self.0.lock_py_attached(py); + if state.is_some() && matches!(&value, Some(v) if !v.is_none(py)) { + return Err(PyTypeError::new_err( + "can't send non-None value to a just-started generator", + )); + } + state + .take() + .ok_or_else(|| PyStopIteration::new_err((value.unwrap_or_else(|| py.None()),))) + } + + fn throw(&self, exc: Bound<'_, PyAny>) -> PyResult> { + Err(PyErr::from_value(exc)) + } + + fn close(&self, py: Python<'_>) { + *self.0.lock_py_attached(py) = None; + } +} + +/// Yielded by `concurrently(...)` to hand a batch of awaitables (each either a `Call`, a +/// coroutine from an async helper, or a nested `concurrently(..)`) to the engine for +/// parallel execution. +#[pyclass( + frozen, + generic, + name = "_Concurrently", + module = "pants.engine.internals.native_engine" +)] +pub struct PyConcurrently { + items: Mutex>>, +} + +impl PyConcurrently { + pub(crate) fn take_items(&self, py: Python<'_>) -> PyResult> { + self.items + .lock_py_attached(py) + .take() + .ok_or_else(|| PyValueError::new_err("A `concurrently(...)` may only be awaited once.")) + } +} + +#[pymethods] +impl PyConcurrently { + #[new] + fn __new__(py: Python<'_>, calls: Bound<'_, PyTuple>) -> PyResult { + let items = calls + .iter() + .map(|item| AllItem::parse(py, item)) + .collect::>>()?; + Ok(Self { + items: Mutex::new(Some(items)), + }) + } + + fn __await__(slf: Py) -> YieldOnce { + YieldOnce::new(slf) + } + + /// Rebuilds the original awaitables as fresh Python objects so test harnesses (e.g. + /// `run_rule_with_mocks`) can iterate and introspect without consuming the engine-side + /// items. Not used on the hot path. + #[getter] + fn calls<'py>(&self, py: Python<'py>) -> PyResult> { + let guard = self.items.lock_py_attached(py); + let items = guard.as_ref().ok_or_else(|| { + PyValueError::new_err("A `concurrently(...)` has already been awaited.") + })?; + let py_items = items + .iter() + .map(|item| item.to_python(py)) + .collect::>>()?; + PyTuple::new(py, py_items) } } @@ -566,11 +686,105 @@ impl PyGeneratorResponseCall { } } +/// The callable `@rule` returns. Captures `rule_id` and `output_type` at decoration time so +/// each invocation constructs the already-awaitable `Call` directly. +/// `__getattribute__` forwards `__doc__` and other introspection attrs to the wrapped function. +#[pyclass(frozen, generic, module = "pants.engine.internals.native_engine")] +pub struct RuleCallTrampoline { + rule_id: PyBackedStr, + #[pyo3(get)] + output_type: Py, + #[pyo3(get, name = "__wrapped__")] + wrapped: Py, + #[pyo3(get)] + rule: Py, +} + +#[pymethods] +impl RuleCallTrampoline { + #[new] + fn __new__( + rule_id: PyBackedStr, + output_type: Py, + wrapped: Py, + rule: Py, + ) -> Self { + Self { + rule_id, + output_type, + wrapped, + rule, + } + } + + #[getter] + fn rule_id(&self) -> &PyBackedStr { + &self.rule_id + } + + #[pyo3(signature = (*args, __implicitly=None, **_kwargs))] + fn __call__( + &self, + py: Python<'_>, + args: &Bound<'_, PyTuple>, + __implicitly: Option<&Bound<'_, PyTuple>>, + _kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult> { + let input_arg0 = match __implicitly { + Some(t) if !t.is_empty() => Some(t.get_item(0)?), + _ => None, + }; + let call = PyGeneratorResponseCall::construct( + py, + &self.rule_id, + self.output_type.bind(py), + args, + input_arg0, + )?; + Py::new(py, call) + } + + /// Forward unknown attribute lookups (`__name__`, `__qualname__`, `__module__`, + /// `__line_number__`, custom attrs, ...) to the wrapped function so introspection that + /// would've relied on `functools.wraps` still works without a per-instance `__dict__`. + fn __getattr__<'py>(&self, py: Python<'py>, name: &str) -> PyResult> { + self.wrapped.bind(py).getattr(name) + } + + /// `__doc__` lives in the type object's `tp_doc` slot, which shadows `#[getter]` and + /// `__getattr__`. Intercepting at `tp_getattro` via `__getattribute__` is the only hook + /// that fires before the slot is read. See PyO3/pyo3#2187. + fn __getattribute__<'py>( + slf: PyRef<'py, Self>, + name: Bound<'py, PyString>, + ) -> PyResult> { + let py = slf.py(); + if name.to_cow()? == "__doc__" { + return slf.wrapped.bind(py).getattr(intern!(py, "__doc__")); + } + unsafe { + Bound::from_owned_ptr_or_err( + py, + pyo3::ffi::PyObject_GenericGetAttr(slf.as_ptr(), name.as_ptr()), + ) + } + } + + fn __repr__(&self, py: Python) -> PyResult { + let name: String = self + .wrapped + .bind(py) + .getattr(intern!(py, "__name__"))? + .extract()?; + Ok(format!("")) + } +} + pub struct NativeCall { pub call: BoxFuture<'static, Result>, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Call { pub rule_id: RuleId, pub output_type: TypeId, @@ -609,10 +823,65 @@ pub enum GeneratorResponse { NativeCall(NativeCall), /// The generator is awaiting a call to a known rule. Call(Call), - /// The generator is awaiting calls to a series of generators, all of which will - /// produce `Call`s. + /// The generator is awaiting completion of a series of awaitables. /// - /// The generators used in this position will either be call-by-name `@rule` stubs (which will - /// immediately produce a `Call`, and then return its value), or async "rule helpers". - All(Vec), + /// Each entry is either a rule `Call` (from a direct `RuleCallTrampoline` invocation, no + /// coroutine wrapper) or a generator from an async rule helper. + All(Vec), } + +#[derive(Clone)] +pub enum AllItem { + /// A generator produced by an `async def` rule helper; drive it via generator_send. + Generator(Value), + /// A direct `Call` returned by a call-by-name `@rule` invocation; execute it as-is. + Call(Call), + /// A nested `concurrently(..)`: recursively join its items and wrap the results in a tuple. + Concurrent(Vec), +} + +impl AllItem { + fn parse(py: Python<'_>, item: Bound<'_, PyAny>) -> PyResult { + if let Ok(call) = item.extract::>() { + call.take(py).map(Self::Call).map_err(PyValueError::new_err) + } else if let Ok(concurrently) = item.extract::>() { + Ok(Self::Concurrent(concurrently.take_items(py)?)) + } else if item.is_instance(&COROUTINE_TYPE.as_py_type(py))? { + Ok(Self::Generator(Value::new(item.unbind()))) + } else { + Err(PyValueError::new_err(format!( + "Expected a `concurrently(..)` argument to be a rule call, coroutine, or \ + nested `concurrently(..)`, but got: {item}" + ))) + } + } + + fn to_python<'py>(&self, py: Python<'py>) -> PyResult> { + match self { + Self::Generator(value) => Ok(value.bind(py).clone()), + Self::Call(call) => { + let py_call = PyGeneratorResponseCall(RwLock::new(Some(call.clone()))); + Ok(Py::new(py, py_call)?.into_bound(py).into_any()) + } + Self::Concurrent(items) => { + let py_concurrently = PyConcurrently { + items: Mutex::new(Some(items.clone())), + }; + Ok(Py::new(py, py_concurrently)?.into_bound(py).into_any()) + } + } + } +} + +static COROUTINE_TYPE: LazyLock = LazyLock::new(|| { + Python::attach(|py| { + let coroutine_type = py + .import("types") + .expect("Failed to import `types`") + .getattr("CoroutineType") + .expect("`types.CoroutineType` is missing") + .cast_into::() + .expect("`types.CoroutineType` is not a type"); + TypeId::new(&coroutine_type) + }) +}); diff --git a/src/rust/engine/src/nodes/task.rs b/src/rust/engine/src/nodes/task.rs index 3e4b4a374f0..081d8b9fa89 100644 --- a/src/rust/engine/src/nodes/task.rs +++ b/src/rust/engine/src/nodes/task.rs @@ -121,6 +121,43 @@ impl Task { .boxed() } + // Dispatches a single `AllItem` to its corresponding driver, recursing into nested + // `concurrently(..)` groups and returning their joined results as a tuple `Value`. + fn gen_all_item( + context: &Context, + params: Params, + entry: Intern>, + item: externs::AllItem, + ) -> BoxFuture<'_, NodeResult> { + async move { + match item { + externs::AllItem::Generator(generator) => { + Self::gen_generator(context, params, entry, generator).await + } + externs::AllItem::Call(call) => Self::gen_call(context, params, entry, call).await, + externs::AllItem::Concurrent(items) => { + let values = Self::gen_all(context, params, entry, items).await?; + Python::attach(|py| externs::store_tuple(py, values)) + .map_err(|err| Python::attach(|py| Failure::from_py_err_with_gil(py, err))) + } + } + } + .boxed() + } + + // Concurrently drives every item in an `All`/`concurrently(..)`. + async fn gen_all( + context: &Context, + params: Params, + entry: Intern>, + items: Vec, + ) -> NodeResult> { + let futures = items + .into_iter() + .map(|item| Self::gen_all_item(context, params.clone(), entry, item)); + future::try_join_all(futures).await + } + /// /// Given a python generator Value, loop to request the generator's dependencies until /// it completes with a result Value or fails with an error. @@ -134,9 +171,7 @@ impl Task { ) -> NodeResult<(Value, TypeId)> { let mut input = GeneratorInput::Initial; loop { - let response = Python::attach(|py| { - externs::generator_send(py, &context.core.types.coroutine, &generator, input) - })?; + let response = Python::attach(|py| externs::generator_send(py, &generator, input))?; match response { GeneratorResponse::NativeCall(call) => { let _blocking_token = workunit.blocking(); @@ -164,15 +199,9 @@ impl Task { Err(failure) => break Err(failure), } } - GeneratorResponse::All(generators) => { + GeneratorResponse::All(items) => { let _blocking_token = workunit.blocking(); - let get_futures = generators - .into_iter() - .map(|generator| { - Self::gen_generator(context, params.clone(), entry, generator) - }) - .collect::>(); - match future::try_join_all(get_futures).await { + match Self::gen_all(context, params.clone(), entry, items).await { Ok(values) => { let values_tuple_result = Python::attach(|py| externs::store_tuple(py, values));