Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions celery-stubs/app/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class Celery(Generic[_T_Global]):
*,
name: str = ...,
serializer: str = ...,
bind: bool = ...,
bind: Literal[True],
autoretry_for: Sequence[type[BaseException]] = ...,
dont_autoretry_for: Sequence[type[BaseException]] = ...,
max_retries: int | None = ...,
Expand Down Expand Up @@ -233,14 +233,14 @@ class Celery(Generic[_T_Global]):
after_return: Callable[..., Any] = ...,
on_retry: Callable[..., Any] = ...,
**options: Any,
) -> Callable[[Callable[..., Any]], _T]: ...
) -> Callable[[Callable[Concatenate[_T, _P], _R]], _T]: ...
@overload
def task(
self,
*,
name: str = ...,
serializer: str = ...,
bind: Literal[False] = ...,
bind: Literal[True],
autoretry_for: Sequence[type[BaseException]] = ...,
dont_autoretry_for: Sequence[type[BaseException]] = ...,
max_retries: int | None = ...,
Expand Down Expand Up @@ -273,14 +273,14 @@ class Celery(Generic[_T_Global]):
after_return: Callable[..., Any] = ...,
on_retry: Callable[..., Any] = ...,
**options: Any,
) -> Callable[[Callable[_P, _R]], _T_Global]: ...
) -> Callable[[Callable[Concatenate[_T_Global, _P], _R]], _T_Global]: ...
@overload
def task(
self,
*,
name: str = ...,
serializer: str = ...,
bind: Literal[True],
bind: Literal[False] = False,
autoretry_for: Sequence[type[BaseException]] = ...,
dont_autoretry_for: Sequence[type[BaseException]] = ...,
max_retries: int | None = ...,
Expand Down Expand Up @@ -313,7 +313,47 @@ class Celery(Generic[_T_Global]):
after_return: Callable[..., Any] = ...,
on_retry: Callable[..., Any] = ...,
**options: Any,
) -> Callable[[Callable[Concatenate[_T_Global, _P], _R]], _T_Global]: ...
) -> Callable[[Callable[_P, _R]], _T_Global]: ...
@overload
def task(
self,
*,
name: str = ...,
serializer: str = ...,
bind: bool = ...,
autoretry_for: Sequence[type[BaseException]] = ...,
dont_autoretry_for: Sequence[type[BaseException]] = ...,
max_retries: int | None = ...,
default_retry_delay: int = ...,
acks_late: bool = ...,
ignore_result: bool = ...,
soft_time_limit: int = ...,
time_limit: int = ...,
base: type[_T],
retry_kwargs: dict[str, Any] = ...,
retry_backoff: bool | int = ...,
retry_backoff_max: int = ...,
retry_jitter: bool = ...,
typing: bool = ...,
rate_limit: str | None = ...,
trail: bool = ...,
send_events: bool = ...,
store_errors_even_if_ignored: bool = ...,
autoregister: bool = ...,
track_started: bool = ...,
acks_on_failure_or_timeout: bool = ...,
reject_on_worker_lost: bool = ...,
throws: tuple[type[Exception], ...] = ...,
expires: float | datetime.datetime | None = ...,
priority: int | None = ...,
resultrepr_maxsize: int = ...,
request_stack: _LocalStack[Context] = ...,
abstract: bool = ...,
queue: str = ...,
after_return: Callable[..., Any] = ...,
on_retry: Callable[..., Any] = ...,
**options: Any,
) -> Callable[[Callable[..., Any]], _T]: ...
def type_checker(
self, fun: Callable[_P, _T_1], bound: bool = False
) -> Callable[_P, _T_1]: ...
Expand Down
41 changes: 33 additions & 8 deletions tests/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from celery.result import AsyncResult, allow_join_result, denied_join_result
from celery.schedules import crontab
from celery.utils.log import get_task_logger
from typing_extensions import override
from typing_extensions import assert_type, override

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -56,18 +56,24 @@ def db(self) -> DB:

@app.task(base=DatabaseTask)
def process_rows(param_1: int) -> None:
assert_type(process_rows, DatabaseTask)

for row in process_rows.db.table.all():
print(row)


@shared_task(base=DatabaseTask)
def process_rows_2(param_1: int) -> None:
assert_type(process_rows_2, DatabaseTask)

for row in process_rows_2.db.table.all():
print(row)


@shared_task(base=DatabaseTask, bind=True)
def process_rows_3(self: DatabaseTask, param_1: int) -> None:
assert_type(process_rows_3, DatabaseTask)

for row in process_rows_3.db.table.all():
print(row)

Expand All @@ -76,35 +82,54 @@ def process_rows_3(self: DatabaseTask, param_1: int) -> None:
# pyright and mypy will report that the typeignore is unnecessary.
@shared_task(base=DatabaseTask, bind=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def process_rows_4(self: int, param_1: int) -> None:
assert_type(process_rows_4, DatabaseTask)

for row in process_rows_4.db.table.all():
print(row)


database_app = Celery[DatabaseTask]()


@database_app.task(name="main.process_rows_4")
def process_rows_5(param_1: int) -> None:
@database_app.task(name="main.process_rows_5")
def process_rows_5(param_1: DatabaseTask) -> None:
assert_type(process_rows_5, DatabaseTask)
for row in process_rows_5.db.table.all():
print(row)


@database_app.task(name="main.process_rows_5", bind=True)
@database_app.task(name="main.process_rows_6", bind=True)
def process_rows_6(self: DatabaseTask, param_1: int) -> None:
assert_type(process_rows_6, DatabaseTask)

for row in process_rows_6.db.table.all():
print(row)


# Here, a typeignore is needed so that when the overload stops working correctly,
# pyright and mypy will report that the typeignore is unnecessary.
@database_app.task(name="main.process_rows_5", bind=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
@database_app.task(name="main.process_rows_7", bind=True) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def process_rows_7(self: int, param_1: int) -> None:
assert_type(process_rows_7, DatabaseTask)

for row in process_rows_7.db.table.all():
print(row)


# Here, a typeignore is needed so that when the overload stops working correctly,
# pyright and mypy will report that the typeignore is unnecessary.
@database_app.task(name="main.process_rows_8", bind=True, base=Task[..., None]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def binded_task_8_fail(self: int, param_1: int) -> None:
pass


@database_app.task(name="main.process_rows_8", bind=True, base=Task[[int], None])
def binded_task_8_ok(self: Task[[int], None], param_1: int) -> None:
assert_type(binded_task_8_ok, Task[[int], None])


@app.task(bind=True, default_retry_delay=10)
def send_twitter_status(self: Task[Any, Any], oauth: str, tweet: str) -> None:
def send_twitter_status(self: Task[[str, str], Any], oauth: str, tweet: str) -> None:
try:
print("fetch stuff")
except KeyError as exc:
Expand Down Expand Up @@ -150,7 +175,7 @@ def foo() -> None:
print("foo")


foo.name
assert_type(foo.name, str)

app_2 = celery.Celery("worker")

Expand Down Expand Up @@ -319,7 +344,7 @@ def baz() -> None:


def test_celery_top_level_exports() -> None:
celery.Celery
celery.Celery[Task[Any, Any]]
celery.Signature[Any]
celery.Task[Any, Any]
celery.chain
Expand Down
Loading