diff --git a/api/app/settings/common.py b/api/app/settings/common.py index 30db5af3714c..9190fa95f2ac 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -517,6 +517,11 @@ "handlers": ["console"], "propagate": False, }, + "webhooks": { + "level": LOG_LEVEL, + "handlers": ["console"], + "propagate": False, + }, }, } diff --git a/api/poetry.lock b/api/poetry.lock index 33765776d655..b8f500fe0f24 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -932,6 +932,20 @@ django = ">=3.2" django-ipware = ">=3" setuptools = "*" +[[package]] +name = "django-capture-on-commit-callbacks" +version = "1.11.0" +description = "Capture and make assertions on transaction.on_commit() callbacks." +optional = false +python-versions = ">=3.7" +files = [ + {file = "django-capture-on-commit-callbacks-1.11.0.tar.gz", hash = "sha256:ee5a79dc74937a0318c192b54d904ce0826ced47748d160bf15324fc77e98c41"}, + {file = "django_capture_on_commit_callbacks-1.11.0-py3-none-any.whl", hash = "sha256:a75300586390411a7e4641128c4251fdc5db25b6e76543329d82fb2c2bc71163"}, +] + +[package.dependencies] +Django = ">=3.2" + [[package]] name = "django-cors-headers" version = "3.5.0" @@ -4327,4 +4341,4 @@ requests = ">=2.7,<3.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "f8530649eadf697a33c9d8ee235569bfb3fc9f461ccdb82f8df0607ec069c019" +content-hash = "34d4e3a4f1c1ccd342d97262a2b6f0bc1c293fad2541ab7faffb5594d25ad16f" diff --git a/api/pyproject.toml b/api/pyproject.toml index 342fb432aabd..2cb6b6e538da 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -130,6 +130,7 @@ pytest-cov = "~4.1.0" datamodel-code-generator = "~0.22" requests-mock = "^1.11.0" pdbpp = "^0.10.3" +django-capture-on-commit-callbacks = "^1.11.0" [build-system] requires = ["poetry-core>=1.5.0"] diff --git a/api/task_processor/decorators.py b/api/task_processor/decorators.py index 2f7e5482415d..6819b85d1fd9 100644 --- a/api/task_processor/decorators.py +++ b/api/task_processor/decorators.py @@ -6,6 +6,7 @@ from threading import Thread from django.conf import settings +from django.db.transaction import on_commit from django.utils import timezone from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError @@ -13,100 +14,157 @@ from task_processor.task_registry import register_task from task_processor.task_run_method import TaskRunMethod +P = typing.ParamSpec("P") + logger = logging.getLogger(__name__) -def register_task_handler( - task_name: str = None, - queue_size: int = None, - priority: TaskPriority = TaskPriority.NORMAL, -): - def decorator(f: typing.Callable): - nonlocal task_name +class TaskHandler(typing.Generic[P]): + __slots__ = ( + "unwrapped", + "queue_size", + "priority", + "transaction_on_commit", + "task_identifier", + ) + + unwrapped: typing.Callable[P, None] + + def __init__( + self, + f: typing.Callable[P, None], + *, + task_name: str | None = None, + queue_size: int | None = None, + priority: TaskPriority = TaskPriority.NORMAL, + transaction_on_commit: bool = True, + ) -> None: + self.unwrapped = f + self.queue_size = queue_size + self.priority = priority + self.transaction_on_commit = transaction_on_commit task_name = task_name or f.__name__ task_module = getmodule(f).__name__.rsplit(".")[-1] - task_identifier = f"{task_module}.{task_name}" + self.task_identifier = task_identifier = f"{task_module}.{task_name}" register_task(task_identifier, f) - def delay( - *, - delay_until: datetime = None, - args: typing.Tuple = (), - kwargs: typing.Dict = None, - ) -> 
typing.Optional[Task]: - logger.debug("Request to run task '%s' asynchronously.", task_identifier) - - kwargs = kwargs or dict() - - if delay_until and settings.TASK_RUN_METHOD != TaskRunMethod.TASK_PROCESSOR: - logger.warning( - "Cannot schedule tasks to run in the future without task processor." + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + _validate_inputs(*args, **kwargs) + return self.unwrapped(*args, **kwargs) + + def delay( + self, + *, + delay_until: datetime | None = None, + # TODO @khvn26 consider typing `args` and `kwargs` with `ParamSpec` + # (will require a change to the signature) + args: tuple[typing.Any] = (), + kwargs: dict[str, typing.Any] | None = None, + ) -> Task | None: + logger.debug("Request to run task '%s' asynchronously.", self.task_identifier) + + kwargs = kwargs or {} + + if delay_until and settings.TASK_RUN_METHOD != TaskRunMethod.TASK_PROCESSOR: + logger.warning( + "Cannot schedule tasks to run in the future without task processor." + ) + return + + if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY: + _validate_inputs(*args, **kwargs) + self.unwrapped(*args, **kwargs) + elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD: + logger.debug("Running task '%s' in separate thread", self.task_identifier) + self.run_in_thread(args=args, kwargs=kwargs) + else: + logger.debug("Creating task for function '%s'...", self.task_identifier) + try: + task = Task.create( + task_identifier=self.task_identifier, + scheduled_for=delay_until or timezone.now(), + priority=self.priority, + queue_size=self.queue_size, + args=args, + kwargs=kwargs, ) + except TaskQueueFullError as e: + logger.warning(e) return - if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY: - _validate_inputs(*args, **kwargs) - f(*args, **kwargs) - elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD: - logger.debug("Running task '%s' in separate thread", task_identifier) - run_in_thread(args=args, kwargs=kwargs) - else: - logger.debug("Creating task for function '%s'...", task_identifier) - try: - task = Task.create( - task_identifier=task_identifier, - scheduled_for=delay_until or timezone.now(), - priority=priority, - queue_size=queue_size, - args=args, - kwargs=kwargs, - ) - except TaskQueueFullError as e: - logger.warning(e) - return - - task.save() - return task - - def run_in_thread(*, args: typing.Tuple = (), kwargs: typing.Dict = None): - logger.info("Running function %s in unmanaged thread.", f.__name__) - _validate_inputs(*args, **kwargs) - Thread(target=f, args=args, kwargs=kwargs, daemon=True).start() - - def _wrapper(*args, **kwargs): - """ - Execute the function after validating the arguments. Ensures that, in unit testing, - the arguments are validated to prevent issues with serialization in an environment - that utilises the task processor. 
-            """
-            _validate_inputs(*args, **kwargs)
-            return f(*args, **kwargs)
-
-        _wrapper.delay = delay
-        _wrapper.run_in_thread = run_in_thread
-        _wrapper.task_identifier = task_identifier
-
-        # patch the original unwrapped function onto the wrapped version for testing
-        _wrapper.unwrapped = f
-
-        return _wrapper
+            task.save()
+            return task
+
+    def run_in_thread(
+        self,
+        *,
+        args: tuple[typing.Any] = (),
+        kwargs: dict[str, typing.Any] | None = None,
+    ) -> None:
+        _validate_inputs(*args, **kwargs)
+        thread = Thread(target=self.unwrapped, args=args, kwargs=kwargs, daemon=True)
+
+        def _start() -> None:
+            logger.info(
+                "Running function %s in unmanaged thread.", self.unwrapped.__name__
+            )
+            thread.start()
+
+        if self.transaction_on_commit:
+            return on_commit(_start)
+        return _start()
+
+
+def register_task_handler(  # noqa: C901
+    *,
+    task_name: str | None = None,
+    queue_size: int | None = None,
+    priority: TaskPriority = TaskPriority.NORMAL,
+    transaction_on_commit: bool = True,
+) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
+    """
+    Turn a function into an asynchronous task.
+
+    :param str task_name: task name. Defaults to function name.
+    :param int queue_size: (`TASK_PROCESSOR` task run method only)
+        max queue size for the task. Task runs that exceed the maximum size are
+        dropped by the task processor. Defaults to `None` (infinite).
+    :param TaskPriority priority: task priority.
+    :param bool transaction_on_commit: (`SEPARATE_THREAD` task run method only)
+        Whether to wrap the task call in `transaction.on_commit`. Defaults to `True`.
+        We need this so the task can access data committed in the current
+        transaction. If the task is invoked outside of a transaction, it will start
+        immediately.
+        Pass `False` if you want the task to start immediately regardless of the
+        current transaction.
+ :rtype: TaskHandler + """ + + def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]: + return TaskHandler( + f, + task_name=task_name, + queue_size=queue_size, + priority=priority, + transaction_on_commit=transaction_on_commit, + ) - return decorator + return wrapper def register_recurring_task( run_every: timedelta, - task_name: str = None, - args: typing.Tuple = (), - kwargs: typing.Dict = None, - first_run_time: time = None, -): + task_name: str | None = None, + args: tuple[typing.Any] = (), + kwargs: dict[str, typing.Any] | None = None, + first_run_time: time | None = None, +) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]: if not os.environ.get("RUN_BY_PROCESSOR"): # Do not register recurring tasks if not invoked by task processor return lambda f: f - def decorator(f: typing.Callable): + def decorator(f: typing.Callable[..., None]) -> RecurringTask: nonlocal task_name task_name = task_name or f.__name__ @@ -118,8 +176,8 @@ def decorator(f: typing.Callable): task, _ = RecurringTask.objects.update_or_create( task_identifier=task_identifier, defaults={ - "serialized_args": RecurringTask.serialize_data(args or tuple()), - "serialized_kwargs": RecurringTask.serialize_data(kwargs or dict()), + "serialized_args": RecurringTask.serialize_data(args or ()), + "serialized_kwargs": RecurringTask.serialize_data(kwargs or {}), "run_every": run_every, "first_run_time": first_run_time, }, @@ -129,9 +187,9 @@ def decorator(f: typing.Callable): return decorator -def _validate_inputs(*args, **kwargs): +def _validate_inputs(*args: typing.Any, **kwargs: typing.Any) -> None: try: - Task.serialize_data(args or tuple()) - Task.serialize_data(kwargs or dict()) + Task.serialize_data(args or ()) + Task.serialize_data(kwargs or {}) except TypeError as e: raise InvalidArgumentsError("Inputs are not serializable.") from e diff --git a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py index 4b730ff8f176..ba472aed1a40 100644 --- a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py +++ b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py @@ -1,9 +1,12 @@ import json import logging from datetime import timedelta +from unittest.mock import MagicMock import pytest +from django_capture_on_commit_callbacks import capture_on_commit_callbacks from pytest_django.fixtures import SettingsWrapper +from pytest_mock import MockerFixture from task_processor.decorators import ( register_recurring_task, @@ -15,8 +18,8 @@ from task_processor.task_run_method import TaskRunMethod -def test_register_task_handler_run_in_thread(mocker, caplog): - # Given +@pytest.fixture +def capture_task_processor_logger(caplog: pytest.LogCaptureFixture) -> None: # caplog doesn't allow you to capture logging outputs from loggers that don't # propagate to root. Quick hack here to get the task_processor logger to # propagate. 
@@ -27,15 +30,64 @@ def test_register_task_handler_run_in_thread(mocker, caplog): task_processor_logger.setLevel(logging.INFO) caplog.set_level(logging.INFO) + +@pytest.fixture +def mock_thread_class( + mocker: MockerFixture, +) -> MagicMock: + mock_thread_class = mocker.patch( + "task_processor.decorators.Thread", + return_value=mocker.MagicMock(), + ) + return mock_thread_class + + +@pytest.mark.django_db +def test_register_task_handler_run_in_thread__transaction_commit__true__default( + capture_task_processor_logger: None, + caplog: pytest.LogCaptureFixture, + mock_thread_class: MagicMock, +) -> None: + # Given @register_task_handler() - def my_function(*args, **kwargs): + def my_function(*args: str, **kwargs: str) -> None: pass - mock_thread = mocker.MagicMock() - mock_thread_class = mocker.patch( - "task_processor.decorators.Thread", return_value=mock_thread + mock_thread = mock_thread_class.return_value + + args = ("foo",) + kwargs = {"bar": "baz"} + + # When + # TODO Switch to pytest-django's django_capture_on_commit_callbacks + # fixture when migrating to Django 4 + with capture_on_commit_callbacks(execute=True): + my_function.run_in_thread(args=args, kwargs=kwargs) + + # Then + mock_thread_class.assert_called_once_with( + target=my_function.unwrapped, args=args, kwargs=kwargs, daemon=True + ) + mock_thread.start.assert_called_once() + + assert len(caplog.records) == 1 + assert ( + caplog.records[0].message == "Running function my_function in unmanaged thread." ) + +def test_register_task_handler_run_in_thread__transaction_commit__false( + capture_task_processor_logger: None, + caplog: pytest.LogCaptureFixture, + mock_thread_class: MagicMock, +) -> None: + # Given + @register_task_handler(transaction_on_commit=False) + def my_function(*args, **kwargs): + pass + + mock_thread = mock_thread_class.return_value + args = ("foo",) kwargs = {"bar": "baz"} diff --git a/api/webhooks/webhooks.py b/api/webhooks/webhooks.py index 22410122f4bc..5b4e5df20b60 100644 --- a/api/webhooks/webhooks.py +++ b/api/webhooks/webhooks.py @@ -1,5 +1,6 @@ import enum import json +import logging import typing import requests @@ -23,6 +24,8 @@ if typing.TYPE_CHECKING: import environments # noqa +logger = logging.getLogger(__name__) + WebhookModels = typing.Union[OrganisationWebhook, "environments.models.Webhook"] @@ -98,7 +101,13 @@ def _call_webhook( signature = sign_payload(json_data, key=webhook.secret) headers.update({FLAGSMITH_SIGNATURE_HEADER: signature}) - return requests.post(str(webhook.url), data=json_data, headers=headers, timeout=10) + try: + return requests.post( + str(webhook.url), data=json_data, headers=headers, timeout=10 + ) + except requests.exceptions.RequestException as exc: + logger.debug("Error calling webhook", exc_info=exc) + raise def _call_webhook_email_on_error(
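Reviewer note: a minimal usage sketch of the new `transaction_on_commit` behaviour, not part of the diff. The task functions and the calling code below are hypothetical; only `register_task_handler` and `run_in_thread` come from the change above.

    from django.db import transaction

    from task_processor.decorators import register_task_handler


    @register_task_handler()  # transaction_on_commit=True is the default
    def create_audit_log(message: str) -> None:
        """Hypothetical task that reads rows written by the caller."""


    @register_task_handler(transaction_on_commit=False)
    def emit_metric(name: str) -> None:
        """Hypothetical fire-and-forget task with no database dependency."""


    def update_environment(environment_id: int) -> None:
        with transaction.atomic():
            # ... write environment changes here ...

            # Deferred: Thread.start() is wrapped in transaction.on_commit, so
            # the task only runs after this transaction commits and its writes
            # are visible to other database connections.
            create_audit_log.run_in_thread(
                args=(f"Environment {environment_id} updated",)
            )

            # Immediate: starts right away, even though a transaction is open.
            emit_metric.run_in_thread(args=("environment.updated",))

Outside of a transaction both calls behave the same, since Django executes `on_commit` callbacks immediately when no transaction is active — matching the decorator docstring.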
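On the new dev dependency: pytest-django runs each test inside a transaction that is rolled back, so `transaction.on_commit` callbacks never fire on their own. `django-capture-on-commit-callbacks` captures them and can execute them on context exit, which is what the `__transaction_commit__true__default` test above relies on. A generic sketch of the pattern (the `create_user` call and the pytest-django `db` fixture usage are illustrative):

    from django_capture_on_commit_callbacks import capture_on_commit_callbacks


    def test_callbacks_run_on_simulated_commit(db) -> None:
        # execute=True runs each captured transaction.on_commit callback when
        # the context manager exits, simulating a successful commit.
        with capture_on_commit_callbacks(execute=True) as callbacks:
            create_user(email="test@example.com")  # hypothetical code under test

        assert len(callbacks) == 1

As the TODO in the test module notes, this can be replaced by pytest-django's built-in `django_capture_on_commit_callbacks` fixture once the project is on Django 4.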