fix: prevent tasks dying from temporary loss of db connection (#3674)
matthewelwell authored Mar 27, 2024
1 parent 0806dbc commit b872a6c
Showing 4 changed files with 91 additions and 23 deletions.
22 changes: 17 additions & 5 deletions api/task_processor/threads.py
@@ -1,7 +1,9 @@
import logging
import time
import traceback
from threading import Thread

from django.db import DatabaseError
from django.utils import timezone

from task_processor.processor import run_recurring_tasks, run_tasks
@@ -27,12 +29,22 @@ def __init__(
    def run(self) -> None:
        while not self._stopped:
            self.last_checked_for_tasks = timezone.now()
            try:
                run_tasks(self.queue_pop_size)
            except Exception as e:
                logger.exception(e)
            run_recurring_tasks(self.queue_pop_size)
            self.run_iteration()
            time.sleep(self.sleep_interval_millis / 1000)

    def run_iteration(self) -> None:
        try:
            run_tasks(self.queue_pop_size)
            run_recurring_tasks(self.queue_pop_size)
        except DatabaseError as e:
            # Catching the database error here prevents the task thread from dying when it
            # fails to retrieve tasks from the database; the thread keeps polling and can
            # resume processing once a connection to the database is re-established.
            # TODO: is this also what is causing tasks to get stuck as locked? Can we unlock
            # tasks here?

            logger.error("Received database error retrieving tasks: %s.", e)
            logger.debug(traceback.format_exc())

    def stop(self):
        self._stopped = True
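
For context, here is a minimal, self-contained sketch of the pattern this hunk introduces. It is illustrative only, not the project's code: ResilientPoller and poll_tasks are stand-ins for TaskRunner and run_tasks/run_recurring_tasks, and DatabaseError stands in for django.db.DatabaseError. The idea is that the polling thread logs and swallows database errors, so a transient loss of connection does not kill it and it simply retries on the next iteration.

```python
import logging
import time
import traceback
from threading import Thread

logger = logging.getLogger(__name__)


class DatabaseError(Exception):
    """Stand-in for django.db.DatabaseError in this sketch."""


def poll_tasks() -> None:
    """Placeholder for the real run_tasks / run_recurring_tasks calls."""


class ResilientPoller(Thread):
    """Polling thread that survives transient database errors."""

    def __init__(self, sleep_interval_millis: int = 2000) -> None:
        super().__init__()
        self.sleep_interval_millis = sleep_interval_millis
        self._stopped = False

    def run(self) -> None:
        while not self._stopped:
            self.run_iteration()
            time.sleep(self.sleep_interval_millis / 1000)

    def run_iteration(self) -> None:
        try:
            poll_tasks()
        except DatabaseError as e:
            # Swallow the error so the thread keeps polling; it will pick up
            # tasks again once the database connection is re-established.
            logger.error("Received database error retrieving tasks: %s.", e)
            logger.debug(traceback.format_exc())

    def stop(self) -> None:
        self._stopped = True
```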
29 changes: 29 additions & 0 deletions api/tests/unit/task_processor/conftest.py
@@ -1,6 +1,35 @@
import logging
import typing

import pytest


@pytest.fixture
def run_by_processor(monkeypatch):
    monkeypatch.setenv("RUN_BY_PROCESSOR", "True")


class GetTaskProcessorCaplog(typing.Protocol):
    def __call__(
        self, log_level: str | int = logging.INFO
    ) -> pytest.LogCaptureFixture: ...


@pytest.fixture
def get_task_processor_caplog(
    caplog: pytest.LogCaptureFixture,
) -> GetTaskProcessorCaplog:
    # caplog doesn't allow you to capture logging outputs from loggers that don't
    # propagate to root. Quick hack here to get the task_processor logger to
    # propagate.
    # TODO: look into using loguru.

    def _inner(log_level: str | int = logging.INFO) -> pytest.LogCaptureFixture:
        task_processor_logger = logging.getLogger("task_processor")
        task_processor_logger.propagate = True
        # Assume required level for the logger.
        task_processor_logger.setLevel(log_level)
        caplog.set_level(log_level)
        return caplog

    return _inner
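
As a hedged usage sketch (the test name and log message below are illustrative, not taken from the commit): a test requests the fixture and calls it to obtain the usual caplog object, which now also captures records from the task_processor logger thanks to the propagation hack above.

```python
import logging

from tests.unit.task_processor.conftest import GetTaskProcessorCaplog


def test_task_processor_logging_is_captured(
    get_task_processor_caplog: GetTaskProcessorCaplog,
) -> None:
    # Given: capture task_processor logs at INFO (the fixture's default level).
    caplog = get_task_processor_caplog(logging.INFO)

    # When: something logs via the "task_processor" logger.
    logging.getLogger("task_processor").info("hello from the task processor")

    # Then: the record shows up on the returned caplog fixture.
    assert any(
        record.message == "hello from the task processor"
        for record in caplog.records
    )
```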
@@ -1,5 +1,4 @@
import json
import logging
from datetime import timedelta
from unittest.mock import MagicMock

@@ -16,19 +15,7 @@
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import get_task
from task_processor.task_run_method import TaskRunMethod


@pytest.fixture
def capture_task_processor_logger(caplog: pytest.LogCaptureFixture) -> None:
    # caplog doesn't allow you to capture logging outputs from loggers that don't
    # propagate to root. Quick hack here to get the task_processor logger to
    # propagate.
    # TODO: look into using loguru.
    task_processor_logger = logging.getLogger("task_processor")
    task_processor_logger.propagate = True
    # Assume required level for the logger.
    task_processor_logger.setLevel(logging.INFO)
    caplog.set_level(logging.INFO)
from tests.unit.task_processor.conftest import GetTaskProcessorCaplog


@pytest.fixture
@@ -44,11 +31,12 @@ def mock_thread_class(

@pytest.mark.django_db
def test_register_task_handler_run_in_thread__transaction_commit__true__default(
    capture_task_processor_logger: None,
    caplog: pytest.LogCaptureFixture,
    get_task_processor_caplog: GetTaskProcessorCaplog,
    mock_thread_class: MagicMock,
) -> None:
    # Given
    caplog = get_task_processor_caplog()

    @register_task_handler()
    def my_function(*args: str, **kwargs: str) -> None:
        pass
@@ -77,11 +65,12 @@ def my_function(*args: str, **kwargs: str) -> None:


def test_register_task_handler_run_in_thread__transaction_commit__false(
    capture_task_processor_logger: None,
    caplog: pytest.LogCaptureFixture,
    get_task_processor_caplog: GetTaskProcessorCaplog,
    mock_thread_class: MagicMock,
) -> None:
    # Given
    caplog = get_task_processor_caplog()

    @register_task_handler(transaction_on_commit=False)
    def my_function(*args, **kwargs):
        pass
38 changes: 38 additions & 0 deletions api/tests/unit/task_processor/test_unit_task_processor_threads.py
@@ -0,0 +1,38 @@
import logging

from django.db import DatabaseError
from pytest_django.fixtures import SettingsWrapper
from pytest_mock import MockerFixture

from task_processor.threads import TaskRunner
from tests.unit.task_processor.conftest import GetTaskProcessorCaplog


def test_task_runner_is_resilient_to_database_errors(
    db: None,
    mocker: MockerFixture,
    get_task_processor_caplog: GetTaskProcessorCaplog,
    settings: SettingsWrapper,
) -> None:
    # Given
    caplog = get_task_processor_caplog(logging.DEBUG)

    task_runner = TaskRunner()
    mocker.patch(
        "task_processor.threads.run_tasks", side_effect=DatabaseError("Database error")
    )

    # When
    task_runner.run_iteration()

    # Then
    assert len(caplog.records) == 2

    assert caplog.records[0].levelno == logging.ERROR
    assert (
        caplog.records[0].message
        == "Received database error retrieving tasks: Database error."
    )

    assert caplog.records[1].levelno == logging.DEBUG
    assert caplog.records[1].message.startswith("Traceback")
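
A possible follow-up test, sketched here for illustration only (it is not part of this commit, and the test name is assumed): make the patched run_tasks fail only on its first call to check that the runner keeps working once the database error clears.

```python
import logging

from django.db import DatabaseError
from pytest_mock import MockerFixture

from task_processor.threads import TaskRunner


def test_task_runner_recovers_after_transient_database_error(
    db: None,
    mocker: MockerFixture,
) -> None:
    # Given: run_tasks raises once, then succeeds on the next call.
    task_runner = TaskRunner()
    run_tasks_mock = mocker.patch(
        "task_processor.threads.run_tasks",
        side_effect=[DatabaseError("Database error"), []],
    )
    mocker.patch("task_processor.threads.run_recurring_tasks")

    # When: two iterations run back to back.
    task_runner.run_iteration()
    task_runner.run_iteration()

    # Then: no exception escaped and both iterations attempted to fetch tasks.
    assert run_tasks_mock.call_count == 2
```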
