|
8 | 8 | from datetime import datetime, timezone |
9 | 9 | from functools import partial |
10 | 10 |
|
| 11 | +from discord.errors import Forbidden |
| 12 | + |
11 | 13 | from pydis_core.utils import logging |
| 14 | +from pydis_core.utils.error_handling import handle_forbidden_from_block |
12 | 15 |
|
13 | 16 | _background_tasks: set[asyncio.Task] = set() |
14 | 17 |
|
@@ -77,7 +80,7 @@ def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: |
77 | 80 | coroutine.close() |
78 | 81 | return |
79 | 82 |
|
80 | | - task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") |
| 83 | + task = asyncio.create_task(_coro_wrapper(coroutine), name=f"{self.name}_{task_id}") |
81 | 84 | task.add_done_callback(partial(self._task_done_callback, task_id)) |
82 | 85 |
|
83 | 86 | self._scheduled_tasks[task_id] = task |
@@ -238,21 +241,29 @@ def create_task( |
238 | 241 | asyncio.Task: The wrapped task. |
239 | 242 | """ |
240 | 243 | if event_loop is not None: |
241 | | - task = event_loop.create_task(coro, **kwargs) |
| 244 | + task = event_loop.create_task(_coro_wrapper(coro), **kwargs) |
242 | 245 | else: |
243 | | - task = asyncio.create_task(coro, **kwargs) |
| 246 | + task = asyncio.create_task(_coro_wrapper(coro), **kwargs) |
244 | 247 |
|
245 | 248 | _background_tasks.add(task) |
246 | 249 | task.add_done_callback(_background_tasks.discard) |
247 | 250 | task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) |
248 | 251 | return task |
249 | 252 |
|
250 | 253 |
|
| 254 | +async def _coro_wrapper(coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN]) -> None: |
| 255 | + """Wraps `coro` in a try/except block that will handle 90001 Forbidden errors.""" |
| 256 | + try: |
| 257 | + await coro |
| 258 | + except Forbidden as e: |
| 259 | + await handle_forbidden_from_block(e) |
| 260 | + |
| 261 | + |
251 | 262 | def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None: |
252 | | - """Retrieve and log the exception raised in ``task`` if one exists.""" |
| 263 | + """Retrieve and log the exception raised in ``task``, if one exists and it's not suppressed.""" |
253 | 264 | with contextlib.suppress(asyncio.CancelledError): |
254 | 265 | exception = task.exception() |
255 | | - # Log the exception if one exists. |
| 266 | + # Log the exception if one exists and it's not suppressed/handled. |
256 | 267 | if exception and not isinstance(exception, suppressed_exceptions): |
257 | 268 | log = logging.get_logger(__name__) |
258 | 269 | log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) |
0 commit comments