Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent tasks dying from temporary loss of db connection #3674

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions api/task_processor/threads.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import time
import traceback
from threading import Thread

from django.db import DatabaseError
from django.utils import timezone

from task_processor.processor import run_recurring_tasks, run_tasks
Expand All @@ -27,12 +29,22 @@ def __init__(
def run(self) -> None:
while not self._stopped:
self.last_checked_for_tasks = timezone.now()
try:
run_tasks(self.queue_pop_size)
except Exception as e:
logger.exception(e)
run_recurring_tasks(self.queue_pop_size)
self.run_iteration()
time.sleep(self.sleep_interval_millis / 1000)

def run_iteration(self) -> None:
    """Run a single processing pass: execute queued tasks, then any
    recurring tasks, each popping up to ``self.queue_pop_size`` items.

    A ``DatabaseError`` is caught here (instead of letting it propagate
    and kill the runner thread) so the polling loop in ``run`` keeps
    going and recovers once the database connection is re-established.
    """
    try:
        run_tasks(self.queue_pop_size)
        run_recurring_tasks(self.queue_pop_size)
    except DatabaseError as e:
        # To prevent task threads from dying if they get an error retrieving the tasks from the
        # database this will allow the thread to continue trying to retrieve tasks if it can
        # successfully re-establish a connection to the database.
        # TODO: is this also what is causing tasks to get stuck as locked? Can we unlock
        # tasks here?

        # Log the message at ERROR but keep the full traceback at DEBUG so
        # transient connection blips don't flood the error log.
        logger.error("Received database error retrieving tasks: %s.", e)
        logger.debug(traceback.format_exc())

def stop(self) -> None:
    # Signal the polling loop in run() to exit; the flag is checked at the
    # top of each iteration, so the thread finishes its current pass first.
    self._stopped = True
29 changes: 29 additions & 0 deletions api/tests/unit/task_processor/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
import logging
import typing

import pytest


@pytest.fixture
def run_by_processor(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set the RUN_BY_PROCESSOR environment variable to "True" for the test."""
    monkeypatch.setenv("RUN_BY_PROCESSOR", "True")


class GetTaskProcessorCaplog(typing.Protocol):
    """Callable returned by the ``get_task_processor_caplog`` fixture.

    Calling it configures the ``task_processor`` logger for capture at
    ``log_level`` and returns the configured ``caplog`` fixture.
    """

    def __call__(
        self, log_level: str | int = logging.INFO
    ) -> pytest.LogCaptureFixture: ...


@pytest.fixture
def get_task_processor_caplog(
    caplog: pytest.LogCaptureFixture,
) -> GetTaskProcessorCaplog:
    """Return a callable that wires ``caplog`` up to the task_processor logger.

    caplog can only capture output from loggers that propagate to root, and
    the ``task_processor`` logger does not by default. The returned callable
    flips propagation on, sets the requested level on both the logger and
    caplog, and hands back the ready-to-use caplog fixture.
    """
    # TODO: look into using loguru.

    def _configured_caplog(
        log_level: str | int = logging.INFO,
    ) -> pytest.LogCaptureFixture:
        tp_logger = logging.getLogger("task_processor")
        tp_logger.propagate = True
        # Assume required level for the logger.
        tp_logger.setLevel(log_level)
        caplog.set_level(log_level)
        return caplog

    return _configured_caplog
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import logging
from datetime import timedelta
from unittest.mock import MagicMock

Expand All @@ -16,19 +15,7 @@
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import get_task
from task_processor.task_run_method import TaskRunMethod


@pytest.fixture
def capture_task_processor_logger(caplog: pytest.LogCaptureFixture) -> None:
# caplog doesn't allow you to capture logging outputs from loggers that don't
# propagate to root. Quick hack here to get the task_processor logger to
# propagate.
# TODO: look into using loguru.
task_processor_logger = logging.getLogger("task_processor")
task_processor_logger.propagate = True
# Assume required level for the logger.
task_processor_logger.setLevel(logging.INFO)
caplog.set_level(logging.INFO)
from tests.unit.task_processor.conftest import GetTaskProcessorCaplog


@pytest.fixture
Expand All @@ -44,11 +31,12 @@ def mock_thread_class(

@pytest.mark.django_db
def test_register_task_handler_run_in_thread__transaction_commit__true__default(
capture_task_processor_logger: None,
caplog: pytest.LogCaptureFixture,
get_task_processor_caplog: GetTaskProcessorCaplog,
mock_thread_class: MagicMock,
) -> None:
# Given
caplog = get_task_processor_caplog()

@register_task_handler()
def my_function(*args: str, **kwargs: str) -> None:
pass
Expand Down Expand Up @@ -77,11 +65,12 @@ def my_function(*args: str, **kwargs: str) -> None:


def test_register_task_handler_run_in_thread__transaction_commit__false(
capture_task_processor_logger: None,
caplog: pytest.LogCaptureFixture,
get_task_processor_caplog: GetTaskProcessorCaplog,
mock_thread_class: MagicMock,
) -> None:
# Given
caplog = get_task_processor_caplog()

@register_task_handler(transaction_on_commit=False)
def my_function(*args, **kwargs):
pass
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

from django.db import DatabaseError
from pytest_django.fixtures import SettingsWrapper
from pytest_mock import MockerFixture

from task_processor.threads import TaskRunner
from tests.unit.task_processor.conftest import GetTaskProcessorCaplog


def test_task_runner_is_resilient_to_database_errors(
    db: None,
    mocker: MockerFixture,
    get_task_processor_caplog: GetTaskProcessorCaplog,
    settings: SettingsWrapper,
) -> None:
    """A DatabaseError while fetching tasks is logged, not raised."""
    # Given
    caplog = get_task_processor_caplog(logging.DEBUG)

    runner = TaskRunner()
    mocker.patch(
        "task_processor.threads.run_tasks",
        side_effect=DatabaseError("Database error"),
    )

    # When
    runner.run_iteration()

    # Then — one ERROR record with the summary, one DEBUG record with the
    # full traceback.
    assert len(caplog.records) == 2
    error_record, debug_record = caplog.records

    assert error_record.levelno == logging.ERROR
    assert (
        error_record.message
        == "Received database error retrieving tasks: Database error."
    )

    assert debug_record.levelno == logging.DEBUG
    assert debug_record.message.startswith("Traceback")