diff --git a/CHANGELOG.md b/CHANGELOG.md index 3098be39c9..4c85701883 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,6 +106,9 @@ These changes are available on the `master` branch, but have not yet been releas ([#2714](https://github.com/Pycord-Development/pycord/pull/2714)) - Added the ability to pass a `datetime.time` object to `format_dt`. ([#2747](https://github.com/Pycord-Development/pycord/pull/2747)) +- Added the ability to pass an `overlap` parameter to the `loop` decorator and `Loop` + class, allowing concurrent iterations if enabled. + ([#2765](https://github.com/Pycord-Development/pycord/pull/2765)) - Added various missing channel parameters and allow `default_reaction_emoji` to be `None`. ([#2772](https://github.com/Pycord-Development/pycord/pull/2772)) - Added support for type hinting slash command options with `typing.Annotated`. diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index af34cc6844..9bdde87f23 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -26,6 +26,7 @@ from __future__ import annotations import asyncio +import contextvars import datetime import inspect import sys @@ -46,6 +47,9 @@ LF = TypeVar("LF", bound=_func) FT = TypeVar("FT", bound=_func) ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) +_current_loop_ctx: contextvars.ContextVar[int] = contextvars.ContextVar( + "_current_loop_ctx", default=None +) class SleepHandle: @@ -59,10 +63,14 @@ def __init__( relative_delta = discord.utils.compute_timedelta(dt) self.handle = loop.call_later(relative_delta, future.set_result, True) + def _set_result_safe(self): + if not self.future.done(): + self.future.set_result(True) + def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) + self.handle = self.loop.call_later(relative_delta, self._set_result_safe) def wait(self) -> asyncio.Future[Any]: return self.future @@ -91,10 +99,12 @@ def __init__( count: int | None, reconnect: bool, loop: asyncio.AbstractEventLoop, + overlap: bool | int, ) -> None: self.coro: LF = coro self.reconnect: bool = reconnect self.loop: asyncio.AbstractEventLoop = loop + self.overlap: bool | int = overlap self.count: int | None = count self._current_loop = 0 self._handle: SleepHandle = MISSING @@ -115,6 +125,7 @@ def __init__( self._is_being_cancelled = False self._has_failed = False self._stop_next_iteration = False + self._tasks: set[asyncio.Task[Any]] = set() if self.count is not None and self.count <= 0: raise ValueError("count must be greater than 0 or None.") @@ -128,6 +139,29 @@ def __init__( raise TypeError( f"Expected coroutine function, not {type(self.coro).__name__!r}." ) + if isinstance(overlap, bool): + if overlap: + self._run_with_semaphore = self._run_direct + elif isinstance(overlap, int): + if overlap <= 1: + raise ValueError("overlap as an integer must be greater than 1.") + self._semaphore = asyncio.Semaphore(overlap) + self._run_with_semaphore = self._semaphore_runner_factory() + else: + raise TypeError("overlap must be a bool or a positive integer.") + + async def _run_direct(self, *args: Any, **kwargs: Any) -> None: + """Run the coroutine directly.""" + await self.coro(*args, **kwargs) + + def _semaphore_runner_factory(self) -> Callable[..., Awaitable[None]]: + """Return a function that runs the coroutine with a semaphore.""" + + async def runner(*args: Any, **kwargs: Any) -> None: + async with self._semaphore: + await self.coro(*args, **kwargs) + + return runner async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: coro = getattr(self, f"_{name}") @@ -166,7 +200,18 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: self._last_iteration = self._next_iteration self._next_iteration = self._get_next_sleep_time() try: - await self.coro(*args, **kwargs) + token = _current_loop_ctx.set(self._current_loop) + if not self.overlap: + await self.coro(*args, **kwargs) + else: + task = asyncio.create_task( + self._run_with_semaphore(*args, **kwargs), + name=f"pycord-loop-{self.coro.__name__}-{self._current_loop}", + ) + task.add_done_callback(self._tasks.discard) + self._tasks.add(task) + + _current_loop_ctx.reset(token) self._last_iteration_failed = False backoff = ExponentialBackoff() except self._valid_exception: @@ -192,6 +237,9 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: except asyncio.CancelledError: self._is_being_cancelled = True + for task in self._tasks: + task.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) raise except Exception as exc: self._has_failed = True @@ -218,6 +266,7 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]: count=self.count, reconnect=self.reconnect, loop=self.loop, + overlap=self.overlap, ) copy._injected = obj copy._before_loop = self._before_loop @@ -269,7 +318,11 @@ def time(self) -> list[datetime.time] | None: @property def current_loop(self) -> int: """The current iteration of the loop.""" - return self._current_loop + return ( + _current_loop_ctx.get() + if _current_loop_ctx.get() is not None + else self._current_loop + ) @property def next_iteration(self) -> datetime.datetime | None: @@ -738,6 +791,7 @@ def loop( count: int | None = None, reconnect: bool = True, loop: asyncio.AbstractEventLoop = MISSING, + overlap: bool | int = False, ) -> Callable[[LF], Loop[LF]]: """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -773,6 +827,11 @@ def loop( loop: :class:`asyncio.AbstractEventLoop` The loop to use to register the task, if not given defaults to :func:`asyncio.get_event_loop`. + overlap: Union[:class:`bool`, :class:`int`] + Controls whether overlapping executions of the task loop are allowed. + Set to False (default) to run iterations one at a time, True for unlimited overlap, or an int to cap the number of concurrent runs. + + .. versionadded:: 2.7 Raises ------ @@ -793,6 +852,7 @@ def decorator(func: LF) -> Loop[LF]: time=time, reconnect=reconnect, loop=loop, + overlap=overlap, ) return decorator diff --git a/examples/background_task.py b/examples/background_task.py index ebbf2a36af..9bba50abda 100644 --- a/examples/background_task.py +++ b/examples/background_task.py @@ -1,3 +1,5 @@ +import asyncio +import random from datetime import time, timezone import discord @@ -10,7 +12,6 @@ def __init__(self, *args, **kwargs): # An attribute we can access from our task self.counter = 0 - # Start the tasks to run in the background self.my_background_task.start() self.time_task.start() @@ -37,6 +38,32 @@ async def time_task(self): async def before_my_task(self): await self.wait_until_ready() # Wait until the bot logs in + # Schedule every 10s; each run takes between 5 to 20s. With overlap=2, at most 2 runs + # execute concurrently so we don't build an ever-growing backlog. + @tasks.loop(seconds=10, overlap=2) + async def fetch_status_task(self): + """ + Practical overlap use-case: + + Poll an external service and post a short summary. Each poll may take + between 5 to 20s due to network latency or rate limits, but we want fresh data + every 10s. Allowing a small amount of overlap avoids drifting schedules + without opening the floodgates to unlimited concurrency. + """ + print(f"[status] start run #{self.fetch_status_task.current_loop}") + + # Simulate slow I/O (e.g., HTTP requests, DB queries, file I/O) + await asyncio.sleep(random.randint(5, 20)) + + channel = self.get_channel(1234567) # Replace with your channel ID + msg = f"[status] run #{self.fetch_status_task.current_loop} complete" + if channel: + await channel.send(msg) + else: + print(msg) + + print(f"[status] end run #{self.fetch_status_task.current_loop}") + client = MyClient() client.run("TOKEN")