From ad793004e543beeeb4ca8a2ac07c42219bb9b2ed Mon Sep 17 00:00:00 2001 From: Denis Diveev Date: Mon, 12 Jan 2026 11:01:40 +0300 Subject: [PATCH] adds overload for using the decorator with bind=True and a custom base class --- celery-stubs/app/base.pyi | 52 ++++++++++++++++++++++++++++++++++----- tests/test_celery.py | 41 ++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 14 deletions(-) diff --git a/celery-stubs/app/base.pyi b/celery-stubs/app/base.pyi index 6bbad18..8d33db6 100644 --- a/celery-stubs/app/base.pyi +++ b/celery-stubs/app/base.pyi @@ -200,7 +200,7 @@ class Celery(Generic[_T_Global]): *, name: str = ..., serializer: str = ..., - bind: bool = ..., + bind: Literal[True], autoretry_for: Sequence[type[BaseException]] = ..., dont_autoretry_for: Sequence[type[BaseException]] = ..., max_retries: int | None = ..., @@ -233,14 +233,14 @@ class Celery(Generic[_T_Global]): after_return: Callable[..., Any] = ..., on_retry: Callable[..., Any] = ..., **options: Any, - ) -> Callable[[Callable[..., Any]], _T]: ... + ) -> Callable[[Callable[Concatenate[_T, _P], _R]], _T]: ... @overload def task( self, *, name: str = ..., serializer: str = ..., - bind: Literal[False] = ..., + bind: Literal[True], autoretry_for: Sequence[type[BaseException]] = ..., dont_autoretry_for: Sequence[type[BaseException]] = ..., max_retries: int | None = ..., @@ -273,14 +273,14 @@ class Celery(Generic[_T_Global]): after_return: Callable[..., Any] = ..., on_retry: Callable[..., Any] = ..., **options: Any, - ) -> Callable[[Callable[_P, _R]], _T_Global]: ... + ) -> Callable[[Callable[Concatenate[_T_Global, _P], _R]], _T_Global]: ... @overload def task( self, *, name: str = ..., serializer: str = ..., - bind: Literal[True], + bind: Literal[False] = False, autoretry_for: Sequence[type[BaseException]] = ..., dont_autoretry_for: Sequence[type[BaseException]] = ..., max_retries: int | None = ..., @@ -313,7 +313,47 @@ class Celery(Generic[_T_Global]): after_return: Callable[..., Any] = ..., on_retry: Callable[..., Any] = ..., **options: Any, - ) -> Callable[[Callable[Concatenate[_T_Global, _P], _R]], _T_Global]: ... + ) -> Callable[[Callable[_P, _R]], _T_Global]: ... + @overload + def task( + self, + *, + name: str = ..., + serializer: str = ..., + bind: bool = ..., + autoretry_for: Sequence[type[BaseException]] = ..., + dont_autoretry_for: Sequence[type[BaseException]] = ..., + max_retries: int | None = ..., + default_retry_delay: int = ..., + acks_late: bool = ..., + ignore_result: bool = ..., + soft_time_limit: int = ..., + time_limit: int = ..., + base: type[_T], + retry_kwargs: dict[str, Any] = ..., + retry_backoff: bool | int = ..., + retry_backoff_max: int = ..., + retry_jitter: bool = ..., + typing: bool = ..., + rate_limit: str | None = ..., + trail: bool = ..., + send_events: bool = ..., + store_errors_even_if_ignored: bool = ..., + autoregister: bool = ..., + track_started: bool = ..., + acks_on_failure_or_timeout: bool = ..., + reject_on_worker_lost: bool = ..., + throws: tuple[type[Exception], ...] = ..., + expires: float | datetime.datetime | None = ..., + priority: int | None = ..., + resultrepr_maxsize: int = ..., + request_stack: _LocalStack[Context] = ..., + abstract: bool = ..., + queue: str = ..., + after_return: Callable[..., Any] = ..., + on_retry: Callable[..., Any] = ..., + **options: Any, + ) -> Callable[[Callable[..., Any]], _T]: ... def type_checker( self, fun: Callable[_P, _T_1], bound: bool = False ) -> Callable[_P, _T_1]: ... diff --git a/tests/test_celery.py b/tests/test_celery.py index 502d30b..59a7791 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -12,7 +12,7 @@ from celery.result import AsyncResult, allow_join_result, denied_join_result from celery.schedules import crontab from celery.utils.log import get_task_logger -from typing_extensions import override +from typing_extensions import assert_type, override if TYPE_CHECKING: from collections.abc import Iterator @@ -56,18 +56,24 @@ def db(self) -> DB: @app.task(base=DatabaseTask) def process_rows(param_1: int) -> None: + assert_type(process_rows, DatabaseTask) + for row in process_rows.db.table.all(): print(row) @shared_task(base=DatabaseTask) def process_rows_2(param_1: int) -> None: + assert_type(process_rows_2, DatabaseTask) + for row in process_rows_2.db.table.all(): print(row) @shared_task(base=DatabaseTask, bind=True) def process_rows_3(self: DatabaseTask, param_1: int) -> None: + assert_type(process_rows_3, DatabaseTask) + for row in process_rows_3.db.table.all(): print(row) @@ -76,6 +82,8 @@ def process_rows_3(self: DatabaseTask, param_1: int) -> None: # pyright and mypy will report that the typeignore is unnecessary. @shared_task(base=DatabaseTask, bind=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] def process_rows_4(self: int, param_1: int) -> None: + assert_type(process_rows_4, DatabaseTask) + for row in process_rows_4.db.table.all(): print(row) @@ -83,28 +91,45 @@ def process_rows_4(self: int, param_1: int) -> None: database_app = Celery[DatabaseTask]() -@database_app.task(name="main.process_rows_4") -def process_rows_5(param_1: int) -> None: +@database_app.task(name="main.process_rows_5") +def process_rows_5(param_1: DatabaseTask) -> None: + assert_type(process_rows_5, DatabaseTask) for row in process_rows_5.db.table.all(): print(row) -@database_app.task(name="main.process_rows_5", bind=True) +@database_app.task(name="main.process_rows_6", bind=True) def process_rows_6(self: DatabaseTask, param_1: int) -> None: + assert_type(process_rows_6, DatabaseTask) + for row in process_rows_6.db.table.all(): print(row) # Here, a typeignore is needed so that when the overload stops working correctly, # pyright and mypy will report that the typeignore is unnecessary. -@database_app.task(name="main.process_rows_5", bind=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +@database_app.task(name="main.process_rows_7", bind=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] def process_rows_7(self: int, param_1: int) -> None: + assert_type(process_rows_7, DatabaseTask) + for row in process_rows_7.db.table.all(): print(row) +# Here, a typeignore is needed so that when the overload stops working correctly, +# pyright and mypy will report that the typeignore is unnecessary. +@database_app.task(name="main.process_rows_8", bind=True, base=Task[..., None]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +def binded_task_8_fail(self: int, param_1: int) -> None: + pass + + +@database_app.task(name="main.process_rows_8", bind=True, base=Task[[int], None]) +def binded_task_8_ok(self: Task[[int], None], param_1: int) -> None: + assert_type(binded_task_8_ok, Task[[int], None]) + + @app.task(bind=True, default_retry_delay=10) -def send_twitter_status(self: Task[Any, Any], oauth: str, tweet: str) -> None: +def send_twitter_status(self: Task[[str, str], Any], oauth: str, tweet: str) -> None: try: print("fetch stuff") except KeyError as exc: @@ -150,7 +175,7 @@ def foo() -> None: print("foo") -foo.name +assert_type(foo.name, str) app_2 = celery.Celery("worker") @@ -319,7 +344,7 @@ def baz() -> None: def test_celery_top_level_exports() -> None: - celery.Celery + celery.Celery[Task[Any, Any]] celery.Signature[Any] celery.Task[Any, Any] celery.chain