From b872a6ca568b0541ca46c2f38135739d214b2f10 Mon Sep 17 00:00:00 2001 From: Matthew Elwell Date: Wed, 27 Mar 2024 17:38:22 +0100 Subject: [PATCH] fix: prevent tasks dying from temporary loss of db connection (#3674) --- api/task_processor/threads.py | 22 ++++++++--- api/tests/unit/task_processor/conftest.py | 29 ++++++++++++++ .../test_unit_task_processor_decorators.py | 25 ++++-------- .../test_unit_task_processor_threads.py | 38 +++++++++++++++++++ 4 files changed, 91 insertions(+), 23 deletions(-) create mode 100644 api/tests/unit/task_processor/test_unit_task_processor_threads.py diff --git a/api/task_processor/threads.py b/api/task_processor/threads.py index d41de67949f9..0cc995274cce 100644 --- a/api/task_processor/threads.py +++ b/api/task_processor/threads.py @@ -1,7 +1,9 @@ import logging import time +import traceback from threading import Thread +from django.db import DatabaseError from django.utils import timezone from task_processor.processor import run_recurring_tasks, run_tasks @@ -27,12 +29,22 @@ def __init__( def run(self) -> None: while not self._stopped: self.last_checked_for_tasks = timezone.now() - try: - run_tasks(self.queue_pop_size) - except Exception as e: - logger.exception(e) - run_recurring_tasks(self.queue_pop_size) + self.run_iteration() time.sleep(self.sleep_interval_millis / 1000) + def run_iteration(self) -> None: + try: + run_tasks(self.queue_pop_size) + run_recurring_tasks(self.queue_pop_size) + except DatabaseError as e: + # To prevent task threads from dying if they get an error retrieving the tasks from the + # database this will allow the thread to continue trying to retrieve tasks if it can + # successfully re-establish a connection to the database. + # TODO: is this also what is causing tasks to get stuck as locked? Can we unlock + # tasks here? + + logger.error("Received database error retrieving tasks: %s.", e) + logger.debug(traceback.format_exc()) + def stop(self): self._stopped = True diff --git a/api/tests/unit/task_processor/conftest.py b/api/tests/unit/task_processor/conftest.py index 107544f72d71..6f38020ec571 100644 --- a/api/tests/unit/task_processor/conftest.py +++ b/api/tests/unit/task_processor/conftest.py @@ -1,6 +1,35 @@ +import logging +import typing + import pytest @pytest.fixture def run_by_processor(monkeypatch): monkeypatch.setenv("RUN_BY_PROCESSOR", "True") + + +class GetTaskProcessorCaplog(typing.Protocol): + def __call__( + self, log_level: str | int = logging.INFO + ) -> pytest.LogCaptureFixture: ... + + +@pytest.fixture +def get_task_processor_caplog( + caplog: pytest.LogCaptureFixture, +) -> GetTaskProcessorCaplog: + # caplog doesn't allow you to capture logging outputs from loggers that don't + # propagate to root. Quick hack here to get the task_processor logger to + # propagate. + # TODO: look into using loguru. + + def _inner(log_level: str | int = logging.INFO) -> pytest.LogCaptureFixture: + task_processor_logger = logging.getLogger("task_processor") + task_processor_logger.propagate = True + # Assume required level for the logger. + task_processor_logger.setLevel(log_level) + caplog.set_level(log_level) + return caplog + + return _inner diff --git a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py index ba472aed1a40..f88046d38925 100644 --- a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py +++ b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py @@ -1,5 +1,4 @@ import json -import logging from datetime import timedelta from unittest.mock import MagicMock @@ -16,19 +15,7 @@ from task_processor.models import RecurringTask, Task, TaskPriority from task_processor.task_registry import get_task from task_processor.task_run_method import TaskRunMethod - - -@pytest.fixture -def capture_task_processor_logger(caplog: pytest.LogCaptureFixture) -> None: - # caplog doesn't allow you to capture logging outputs from loggers that don't - # propagate to root. Quick hack here to get the task_processor logger to - # propagate. - # TODO: look into using loguru. - task_processor_logger = logging.getLogger("task_processor") - task_processor_logger.propagate = True - # Assume required level for the logger. - task_processor_logger.setLevel(logging.INFO) - caplog.set_level(logging.INFO) +from tests.unit.task_processor.conftest import GetTaskProcessorCaplog @pytest.fixture @@ -44,11 +31,12 @@ def mock_thread_class( @pytest.mark.django_db def test_register_task_handler_run_in_thread__transaction_commit__true__default( - capture_task_processor_logger: None, - caplog: pytest.LogCaptureFixture, + get_task_processor_caplog: GetTaskProcessorCaplog, mock_thread_class: MagicMock, ) -> None: # Given + caplog = get_task_processor_caplog() + @register_task_handler() def my_function(*args: str, **kwargs: str) -> None: pass @@ -77,11 +65,12 @@ def my_function(*args: str, **kwargs: str) -> None: def test_register_task_handler_run_in_thread__transaction_commit__false( - capture_task_processor_logger: None, - caplog: pytest.LogCaptureFixture, + get_task_processor_caplog: GetTaskProcessorCaplog, mock_thread_class: MagicMock, ) -> None: # Given + caplog = get_task_processor_caplog() + @register_task_handler(transaction_on_commit=False) def my_function(*args, **kwargs): pass diff --git a/api/tests/unit/task_processor/test_unit_task_processor_threads.py b/api/tests/unit/task_processor/test_unit_task_processor_threads.py new file mode 100644 index 000000000000..391c3fd14325 --- /dev/null +++ b/api/tests/unit/task_processor/test_unit_task_processor_threads.py @@ -0,0 +1,38 @@ +import logging + +from django.db import DatabaseError +from pytest_django.fixtures import SettingsWrapper +from pytest_mock import MockerFixture + +from task_processor.threads import TaskRunner +from tests.unit.task_processor.conftest import GetTaskProcessorCaplog + + +def test_task_runner_is_resilient_to_database_errors( + db: None, + mocker: MockerFixture, + get_task_processor_caplog: GetTaskProcessorCaplog, + settings: SettingsWrapper, +) -> None: + # Given + caplog = get_task_processor_caplog(logging.DEBUG) + + task_runner = TaskRunner() + mocker.patch( + "task_processor.threads.run_tasks", side_effect=DatabaseError("Database error") + ) + + # When + task_runner.run_iteration() + + # Then + assert len(caplog.records) == 2 + + assert caplog.records[0].levelno == logging.ERROR + assert ( + caplog.records[0].message + == "Received database error retrieving tasks: Database error." + ) + + assert caplog.records[1].levelno == logging.DEBUG + assert caplog.records[1].message.startswith("Traceback")