diff --git a/django_celery_results/backends/database.py b/django_celery_results/backends/database.py
index a4e364a6..664c7120 100644
--- a/django_celery_results/backends/database.py
+++ b/django_celery_results/backends/database.py
@@ -7,7 +7,7 @@
 from celery.result import GroupResult, allow_join_result, result_from_tuple
 from celery.utils.log import get_logger
 from celery.utils.serialization import b64decode, b64encode
-from django.db import connection, router, transaction
+from django.db import connection, connections, router, transaction
 from django.db.models.functions import Now
 from django.db.utils import InterfaceError
 from kombu.exceptions import DecodeError
@@ -120,6 +120,17 @@ def _store_result(
         using=None
     ):
         """Store return value and status of an executed task."""
+
+        # If a task has been running for a long time, it may have
+        # exceeded the max db age and/or the database connection
+        # may have been closed due to being idle for too long.
+        # As a safety, before we submit the result,
+        # we ensure it still has a valid connection, just like
+        # Django does after a request to ensure a
+        # clean connection for the next request.
+        (connections[router.db_for_write(self.TaskModel)]
+         .close_if_unusable_or_obsolete())
+
         content_type, content_encoding, result = self.encode_content(result)
 
         meta = {
@@ -147,7 +158,6 @@ def _store_result(
 
         if status == states.STARTED:
             task_props['date_started'] = Now()
-
         self.TaskModel._default_manager.store_result(**task_props)
         return result
 
diff --git a/t/proj/settings.py b/t/proj/settings.py
index 6e974b51..d347f14d 100644
--- a/t/proj/settings.py
+++ b/t/proj/settings.py
@@ -36,6 +36,7 @@
             'OPTIONS': {
                 'connect_timeout': 1000,
             },
+            'CONN_MAX_AGE': None,
         },
         'secondary': {
             'ENGINE': 'django.db.backends.postgresql',
@@ -50,6 +51,7 @@
             'TEST': {
                 'MIRROR': 'default',
             },
+            'CONN_MAX_AGE': None,
         },
         'read-only': {
             'ENGINE': 'django.db.backends.postgresql',
@@ -65,6 +67,7 @@
             'TEST': {
                 'MIRROR': 'default',
             },
+            'CONN_MAX_AGE': None,
         },
     }
 except ImportError:
diff --git a/t/unit/backends/test_database.py b/t/unit/backends/test_database.py
index 8baa6cc0..2c8b36e0 100644
--- a/t/unit/backends/test_database.py
+++ b/t/unit/backends/test_database.py
@@ -2,6 +2,7 @@
 import json
 import pickle
 import re
+import time
 from unittest import mock
 
 import celery
@@ -12,6 +13,7 @@
 from celery.utils.serialization import b64decode
 from celery.worker.request import Request
 from celery.worker.strategy import hybrid_to_proto2
+from django.db import connections, router
 from django.test import TransactionTestCase
 
 from django_celery_results.backends.database import DatabaseBackend
@@ -24,7 +26,7 @@ def __init__(self, data):
         self.data = data
 
 
-@pytest.mark.django_db()
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.usefixtures('depends_on_current_app')
 class test_DatabaseBackend:
 
@@ -550,6 +552,25 @@ def test_backend__task_result_meta_injection(self):
         tr = TaskResult.objects.get(task_id=tid2)
         assert json.loads(tr.meta) == {'key': 'value', 'children': []}
 
+    def test_backend__task_result_closes_stale_connection(self):
+        tid = uuid()
+        request = self._create_request(
+            task_id=tid,
+            name='my_task',
+            args=[],
+            kwargs={},
+            task_protocol=1,
+        )
+        # Simulate a stale connection by setting the close time
+        # to the current time.
+        db_conn_wrapper = connections[router.db_for_write(self.b.TaskModel)]
+        db_conn_wrapper.close_at = time.monotonic()
+        current_db_connection = db_conn_wrapper.connection
+        self.b.mark_as_done(tid, None, request=request)
+        # Validate that the connection was replaced in the process
+        # of saving the task result.
+        assert current_db_connection is not db_conn_wrapper.connection
+
     def test_backend__task_result_date(self):
         tid2 = uuid()
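
Note on the pattern the patch introduces: close_if_unusable_or_obsolete() on a Django connection wrapper closes the underlying connection when it has errored out or has outlived CONN_MAX_AGE, and the next ORM query then reconnects lazily; Django applies the same check to every connection via django.db.close_old_connections() around each request, which is the behaviour the new comment in _store_result refers to. Below is a minimal standalone sketch of that guard, using only public Django APIs; the function name and the MyTaskModel placeholder are illustrative and not part of this project.

from django.db import connections, router


def refresh_write_connection(model):
    """Close the write connection for the given model if it is unusable
    or has outlived CONN_MAX_AGE; the next query reconnects lazily."""
    alias = router.db_for_write(model)  # same routing call the backend uses
    connections[alias].close_if_unusable_or_obsolete()


# Hypothetical usage inside a long-running worker, just before a write:
#   refresh_write_connection(MyTaskModel)
#   MyTaskModel.objects.create(...)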