fix: Audit Log records don't get created with threaded task processing #2958

Merged 6 commits on Nov 13, 2023. Diff shown: changes from 5 commits.
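The root cause, in short: with TASK_RUN_METHOD=SEPARATE_THREAD, task handlers (audit log record creation among them) were started on a raw Thread while the request's database transaction was still open, so the thread could query for rows that were not yet committed and silently do nothing. The decorator change below defers the thread start with Django's transaction.on_commit. A minimal sketch of the race, assuming a configured Django project; create_audit_log_record and save_something stand in for real code:

# Sketch only: illustrates the race this PR fixes.
from threading import Thread

from django.db import transaction


def create_audit_log_record(record_id: int) -> None:
    ...  # reads rows written inside the request's transaction


with transaction.atomic():
    record_id = save_something()  # hypothetical write, not yet committed

    # Before this PR: the thread may run before COMMIT and find no row.
    Thread(target=create_audit_log_record, args=(record_id,), daemon=True).start()

    # After this PR: the thread starts only once the transaction commits.
    transaction.on_commit(
        lambda: Thread(
            target=create_audit_log_record, args=(record_id,), daemon=True
        ).start()
    )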
5 changes: 5 additions & 0 deletions api/app/settings/common.py
@@ -517,6 +517,11 @@
             "handlers": ["console"],
             "propagate": False,
         },
+        "webhooks": {
+            "level": LOG_LEVEL,
+            "handlers": ["console"],
+            "propagate": False,
+        },
     },
 }

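What the new logger block buys: records emitted under the "webhooks" logger namespace are now explicitly routed to the console handler at LOG_LEVEL, with propagation to the root logger disabled, matching the other entries in this LOGGING config. A small sketch (the module path is illustrative):

import logging

# In e.g. webhooks/webhooks.py, getLogger(__name__) resolves to a child of
# the "webhooks" logger configured above and is handled by its console handler:
logger = logging.getLogger("webhooks.webhooks")

logger.warning("Webhook delivery failed")  # reaches the console; propagate=False keeps it off root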
16 changes: 15 additions & 1 deletion api/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions api/pyproject.toml
@@ -130,6 +130,7 @@ pytest-cov = "~4.1.0"
 datamodel-code-generator = "~0.22"
 requests-mock = "^1.11.0"
 pdbpp = "^0.10.3"
+django-capture-on-commit-callbacks = "^1.11.0"
 
 [build-system]
 requires = ["poetry-core>=1.5.0"]
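The new dev dependency provides Django's captureOnCommitCallbacks test helper as a standalone context manager, which the updated tests below use to force transaction.on_commit callbacks to run; per the TODO in those tests, it can be swapped for pytest-django's django_capture_on_commit_callbacks fixture once the project is on Django 4. A minimal usage sketch (the function under test and the db fixture usage are illustrative):

from django_capture_on_commit_callbacks import capture_on_commit_callbacks


def test_enqueues_work_on_commit(db) -> None:
    # Callbacks registered via transaction.on_commit inside the block are
    # captured and, with execute=True, executed when the block exits.
    with capture_on_commit_callbacks(execute=True) as callbacks:
        do_work_that_registers_an_on_commit_hook()  # hypothetical

    assert len(callbacks) == 1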
211 changes: 132 additions & 79 deletions api/task_processor/decorators.py
@@ -6,107 +6,160 @@
 from threading import Thread
 
 from django.conf import settings
+from django.db.transaction import on_commit
 from django.utils import timezone
 
 from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError
 from task_processor.models import RecurringTask, Task, TaskPriority
 from task_processor.task_registry import register_task
 from task_processor.task_run_method import TaskRunMethod
 
+P = typing.ParamSpec("P")
+
 logger = logging.getLogger(__name__)
 
 
-def register_task_handler(
-    task_name: str = None,
-    queue_size: int = None,
-    priority: TaskPriority = TaskPriority.NORMAL,
-):
-    def decorator(f: typing.Callable):
-        nonlocal task_name
+class TaskHandler(typing.Generic[P]):
+    __slots__ = (
+        "unwrapped",
+        "queue_size",
+        "priority",
+        "transaction_on_commit",
+        "task_identifier",
+    )
+
+    unwrapped: typing.Callable[P, None]
+
+    def __init__(
+        self,
+        f: typing.Callable[P, None],
+        *,
+        task_name: str | None = None,
+        queue_size: int | None = None,
+        priority: TaskPriority = TaskPriority.NORMAL,
+        transaction_on_commit: bool = True,
+    ) -> None:
+        self.unwrapped = f
+        self.queue_size = queue_size
+        self.priority = priority
+        self.transaction_on_commit = transaction_on_commit
 
         task_name = task_name or f.__name__
         task_module = getmodule(f).__name__.rsplit(".")[-1]
-        task_identifier = f"{task_module}.{task_name}"
+        self.task_identifier = task_identifier = f"{task_module}.{task_name}"
         register_task(task_identifier, f)
 
-        def delay(
-            *,
-            delay_until: datetime = None,
-            args: typing.Tuple = (),
-            kwargs: typing.Dict = None,
-        ) -> typing.Optional[Task]:
-            logger.debug("Request to run task '%s' asynchronously.", task_identifier)
-
-            kwargs = kwargs or dict()
-
-            if delay_until and settings.TASK_RUN_METHOD != TaskRunMethod.TASK_PROCESSOR:
-                logger.warning(
-                    "Cannot schedule tasks to run in the future without task processor."
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None:
+        _validate_inputs(*args, **kwargs)
+        return self.unwrapped(*args, **kwargs)
+
+    def delay(
+        self,
+        *,
+        delay_until: datetime | None = None,
+        # TODO @khvn26 consider typing `args` and `kwargs` with `ParamSpec`
+        # (will require a change to the signature)
+        args: tuple[typing.Any] = (),
+        kwargs: dict[str, typing.Any] | None = None,
+    ) -> Task | None:
+        logger.debug("Request to run task '%s' asynchronously.", self.task_identifier)
+
+        kwargs = kwargs or {}
+
+        if delay_until and settings.TASK_RUN_METHOD != TaskRunMethod.TASK_PROCESSOR:
+            logger.warning(
+                "Cannot schedule tasks to run in the future without task processor."
             )
             return
 
+        if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY:
+            _validate_inputs(*args, **kwargs)
+            self.unwrapped(*args, **kwargs)
+        elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD:
+            logger.debug("Running task '%s' in separate thread", self.task_identifier)
+            self.run_in_thread(args=args, kwargs=kwargs)
+        else:
+            logger.debug("Creating task for function '%s'...", self.task_identifier)
+            try:
+                task = Task.create(
+                    task_identifier=self.task_identifier,
+                    scheduled_for=delay_until or timezone.now(),
+                    priority=self.priority,
+                    queue_size=self.queue_size,
+                    args=args,
+                    kwargs=kwargs,
+                )
+            except TaskQueueFullError as e:
+                logger.warning(e)
+                return
 
-            if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY:
-                _validate_inputs(*args, **kwargs)
-                f(*args, **kwargs)
-            elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD:
-                logger.debug("Running task '%s' in separate thread", task_identifier)
-                run_in_thread(args=args, kwargs=kwargs)
-            else:
-                logger.debug("Creating task for function '%s'...", task_identifier)
-                try:
-                    task = Task.create(
-                        task_identifier=task_identifier,
-                        scheduled_for=delay_until or timezone.now(),
-                        priority=priority,
-                        queue_size=queue_size,
-                        args=args,
-                        kwargs=kwargs,
-                    )
-                except TaskQueueFullError as e:
-                    logger.warning(e)
-                    return
-
-                task.save()
-                return task
-
-        def run_in_thread(*, args: typing.Tuple = (), kwargs: typing.Dict = None):
-            logger.info("Running function %s in unmanaged thread.", f.__name__)
-            _validate_inputs(*args, **kwargs)
-            Thread(target=f, args=args, kwargs=kwargs, daemon=True).start()
-
-        def _wrapper(*args, **kwargs):
-            """
-            Execute the function after validating the arguments. Ensures that, in unit testing,
-            the arguments are validated to prevent issues with serialization in an environment
-            that utilises the task processor.
-            """
-            _validate_inputs(*args, **kwargs)
-            return f(*args, **kwargs)
-
-        _wrapper.delay = delay
-        _wrapper.run_in_thread = run_in_thread
-        _wrapper.task_identifier = task_identifier
-
-        # patch the original unwrapped function onto the wrapped version for testing
-        _wrapper.unwrapped = f
-
-        return _wrapper
+            task.save()
+            return task
+
+    def run_in_thread(
+        self,
+        *,
+        args: tuple[typing.Any] = (),
+        kwargs: dict[str, typing.Any] | None = None,
+    ) -> None:
+        _validate_inputs(*args, **kwargs)
+        thread = Thread(target=self.unwrapped, args=args, kwargs=kwargs, daemon=True)
+
+        def _start() -> None:
+            logger.info(
+                "Running function %s in unmanaged thread.", self.unwrapped.__name__
+            )
+            thread.start()
+
+        if self.transaction_on_commit:
+            return on_commit(_start)
+        return _start()
+
+
+def register_task_handler(  # noqa: C901
+    *,
+    task_name: str | None = None,
+    queue_size: int | None = None,
+    priority: TaskPriority = TaskPriority.NORMAL,
+    transaction_on_commit: bool = True,
+) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
+    """
+    Turn a function into an asynchronous task.
+
+    :param str task_name: task name. Defaults to function name.
+    :param int queue_size: (`TASK_PROCESSOR` task run method only)
+        max queue size for the task. Task runs exceeding the max size get dropped by
+        the task processor. Defaults to `None` (infinite).
+    :param TaskPriority priority: task priority.
+    :param bool transaction_on_commit: (`SEPARATE_THREAD` task run method only)
+        Whether to wrap the task call in `transaction.on_commit`. Defaults to `True`.
+    :rtype: TaskProtocol
+    """
+
+    def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
+        return TaskHandler(
+            f,
+            task_name=task_name,
+            queue_size=queue_size,
+            priority=priority,
+            transaction_on_commit=transaction_on_commit,
+        )
 
-    return decorator
+    return wrapper
 
 
 def register_recurring_task(
     run_every: timedelta,
-    task_name: str = None,
-    args: typing.Tuple = (),
-    kwargs: typing.Dict = None,
-    first_run_time: time = None,
-):
+    task_name: str | None = None,
+    args: tuple[typing.Any] = (),
+    kwargs: dict[str, typing.Any] | None = None,
+    first_run_time: time | None = None,
+) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
     if not os.environ.get("RUN_BY_PROCESSOR"):
         # Do not register recurring tasks if not invoked by task processor
         return lambda f: f
 
-    def decorator(f: typing.Callable):
+    def decorator(f: typing.Callable[..., None]) -> RecurringTask:
         nonlocal task_name
 
         task_name = task_name or f.__name__
@@ -118,8 +118,8 @@ def decorator(f: typing.Callable):
         task, _ = RecurringTask.objects.update_or_create(
             task_identifier=task_identifier,
             defaults={
-                "serialized_args": RecurringTask.serialize_data(args or tuple()),
-                "serialized_kwargs": RecurringTask.serialize_data(kwargs or dict()),
+                "serialized_args": RecurringTask.serialize_data(args or ()),
+                "serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
                 "run_every": run_every,
                 "first_run_time": first_run_time,
             },
@@ -129,9 +129,9 @@ def decorator(f: typing.Callable):
     return decorator
 
 
-def _validate_inputs(*args, **kwargs):
+def _validate_inputs(*args: typing.Any, **kwargs: typing.Any) -> None:
     try:
-        Task.serialize_data(args or tuple())
-        Task.serialize_data(kwargs or dict())
+        Task.serialize_data(args or ())
+        Task.serialize_data(kwargs or {})
     except TypeError as e:
         raise InvalidArgumentsError("Inputs are not serializable.") from e
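Taken together, the module keeps the same decorator surface and adds one knob. A hedged usage sketch (the task body and arguments are illustrative):

from task_processor.decorators import register_task_handler


@register_task_handler()  # transaction_on_commit defaults to True
def create_audit_log_record_task(record_id: int) -> None:
    ...


# In request code, typically inside transaction.atomic():
create_audit_log_record_task.delay(args=(42,))

# With TASK_RUN_METHOD=SEPARATE_THREAD the thread now starts via
# transaction.on_commit; opt out per task if an immediate start is required:
# @register_task_handler(transaction_on_commit=False)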
api/task_processor decorator unit tests (file header not rendered)
@@ -1,9 +1,12 @@
 import json
 import logging
 from datetime import timedelta
+from unittest.mock import MagicMock
 
 import pytest
+from django_capture_on_commit_callbacks import capture_on_commit_callbacks
 from pytest_django.fixtures import SettingsWrapper
+from pytest_mock import MockerFixture
 
 from task_processor.decorators import (
     register_recurring_task,
@@ -15,8 +18,8 @@
 from task_processor.task_run_method import TaskRunMethod
 
 
-def test_register_task_handler_run_in_thread(mocker, caplog):
-    # Given
+@pytest.fixture
+def capture_task_processor_logger(caplog: pytest.LogCaptureFixture) -> None:
     # caplog doesn't allow you to capture logging outputs from loggers that don't
     # propagate to root. Quick hack here to get the task_processor logger to
     # propagate.
@@ -27,15 +30,64 @@ def test_register_task_handler_run_in_thread(mocker, caplog):
     task_processor_logger.setLevel(logging.INFO)
     caplog.set_level(logging.INFO)
 
+
+@pytest.fixture
+def mock_thread_class(
+    mocker: MockerFixture,
+) -> MagicMock:
+    mock_thread_class = mocker.patch(
+        "task_processor.decorators.Thread",
+        return_value=mocker.MagicMock(),
+    )
+    return mock_thread_class
+
+
+@pytest.mark.django_db
+def test_register_task_handler_run_in_thread__transaction_commit__true__default(
+    capture_task_processor_logger: None,
+    caplog: pytest.LogCaptureFixture,
+    mock_thread_class: MagicMock,
+) -> None:
+    # Given
     @register_task_handler()
-    def my_function(*args, **kwargs):
+    def my_function(*args: str, **kwargs: str) -> None:
         pass
 
-    mock_thread = mocker.MagicMock()
-    mock_thread_class = mocker.patch(
-        "task_processor.decorators.Thread", return_value=mock_thread
+    mock_thread = mock_thread_class.return_value
+
+    args = ("foo",)
+    kwargs = {"bar": "baz"}
+
+    # When
+    # TODO Switch to pytest-django's django_capture_on_commit_callbacks
+    # fixture when migrating to Django 4
+    with capture_on_commit_callbacks(execute=True):
+        my_function.run_in_thread(args=args, kwargs=kwargs)
+
+    # Then
+    mock_thread_class.assert_called_once_with(
+        target=my_function.unwrapped, args=args, kwargs=kwargs, daemon=True
     )
+    mock_thread.start.assert_called_once()
+
+    assert len(caplog.records) == 1
+    assert (
+        caplog.records[0].message == "Running function my_function in unmanaged thread."
+    )
+
+
+def test_register_task_handler_run_in_thread__transaction_commit__false(
+    capture_task_processor_logger: None,
+    caplog: pytest.LogCaptureFixture,
+    mock_thread_class: MagicMock,
+) -> None:
+    # Given
+    @register_task_handler(transaction_on_commit=False)
+    def my_function(*args, **kwargs):
+        pass
+
+    mock_thread = mock_thread_class.return_value
+
+    args = ("foo",)
+    kwargs = {"bar": "baz"}
+