Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions djcelery_transactions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
from celery.task import task as base_task, Task
from celery import task as base_task, Task
import djcelery_transactions.transaction_signals
from django.db import transaction
from functools import partial
Expand Down Expand Up @@ -39,15 +39,13 @@ def example(pk):

abstract = True

@classmethod
def original_apply_async(cls, *args, **kwargs):
def original_apply_async(self, *args, **kwargs):
"""Shortcut method to reach real implementation
of celery.Task.apply_sync
"""
return super(PostTransactionTask, cls).apply_async(*args, **kwargs)
return super(PostTransactionTask, self).apply_async(*args, **kwargs)

@classmethod
def apply_async(cls, *args, **kwargs):
def apply_async(self, *args, **kwargs):
# Delay the task unless the client requested otherwise or transactions
# aren't being managed (i.e. the signal handlers won't send the task).
if transaction.is_managed():
Expand All @@ -58,9 +56,9 @@ def apply_async(cls, *args, **kwargs):
transaction.set_dirty(using=kwargs['using'])
else:
transaction.set_dirty()
_get_task_queue().append((cls, args, kwargs))
_get_task_queue().append((self, args, kwargs))
else:
return cls.original_apply_async(*args, **kwargs)
return self.original_apply_async(*args, **kwargs)


def _discard_tasks(**kwargs):
Expand All @@ -78,8 +76,8 @@ def _send_tasks(**kwargs):
"""
queue = _get_task_queue()
while queue:
cls, args, kwargs = queue.pop(0)
cls.original_apply_async(*args, **kwargs)
self, args, kwargs = queue.pop(0)
self.original_apply_async(*args, **kwargs)


# A replacement decorator.
Expand Down