From d947474696a3a213ff196ffc5ac3bf802dbd8062 Mon Sep 17 00:00:00 2001
From: Matthew Elwell
Date: Wed, 6 Sep 2023 16:22:03 +0100
Subject: [PATCH] feat(task-processor): validate arguments passed to task processor functions (#2747)

---
Review note: run_in_thread() declares kwargs=None, so the validation call
unpacks `kwargs or {}`; unpacking None with ** raises TypeError before the
thread is ever started.

 api/task_processor/decorators.py              | 31 ++++++++++++--
 api/task_processor/exceptions.py              |  4 ++
 api/tests/unit/audit/test_unit_audit_tasks.py |  2 +-
 .../test_unit_task_processor_decorators.py    | 40 ++++++++++++++++++-
 4 files changed, 71 insertions(+), 6 deletions(-)

diff --git a/api/task_processor/decorators.py b/api/task_processor/decorators.py
index 5e4886af94e6..76c1d639b1cb 100644
--- a/api/task_processor/decorators.py
+++ b/api/task_processor/decorators.py
@@ -8,6 +8,7 @@
 from django.conf import settings
 from django.utils import timezone
 
+from task_processor.exceptions import InvalidArgumentsError
 from task_processor.models import RecurringTask, Task
 from task_processor.task_registry import register_task
 from task_processor.task_run_method import TaskRunMethod
@@ -41,6 +42,7 @@ def delay(
                 return
 
             if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY:
+                _validate_inputs(*args, **kwargs)
                 f(*args, **kwargs)
             elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD:
                 logger.debug("Running task '%s' in separate thread", task_identifier)
@@ -58,13 +60,26 @@ def delay(
 
         def run_in_thread(*, args: typing.Tuple = (), kwargs: typing.Dict = None):
             logger.info("Running function %s in unmanaged thread.", f.__name__)
+            _validate_inputs(*args, **(kwargs or {}))
             Thread(target=f, args=args, kwargs=kwargs, daemon=True).start()
 
-        f.delay = delay
-        f.run_in_thread = run_in_thread
-        f.task_identifier = task_identifier
+        def _wrapper(*args, **kwargs):
+            """
+            Execute the function after validating the arguments. Ensures that, in unit testing,
+            the arguments are validated to prevent issues with serialization in an environment
+            that utilises the task processor.
+            """
+            _validate_inputs(*args, **kwargs)
+            return f(*args, **kwargs)
 
-        return f
+        _wrapper.delay = delay
+        _wrapper.run_in_thread = run_in_thread
+        _wrapper.task_identifier = task_identifier
+
+        # patch the original unwrapped function onto the wrapped version for testing
+        _wrapper.unwrapped = f
+
+        return _wrapper
 
     return decorator
 
@@ -101,3 +116,11 @@ def decorator(f: typing.Callable):
         return task
 
     return decorator
+
+
+def _validate_inputs(*args, **kwargs):
+    try:
+        Task.serialize_data(args or tuple())
+        Task.serialize_data(kwargs or dict())
+    except TypeError as e:
+        raise InvalidArgumentsError("Inputs are not serializable.") from e
diff --git a/api/task_processor/exceptions.py b/api/task_processor/exceptions.py
index 4cbb4e5bfe8d..12cf27f73a7e 100644
--- a/api/task_processor/exceptions.py
+++ b/api/task_processor/exceptions.py
@@ -1,2 +1,6 @@
 class TaskProcessingError(Exception):
     pass
+
+
+class InvalidArgumentsError(TaskProcessingError):
+    pass
diff --git a/api/tests/unit/audit/test_unit_audit_tasks.py b/api/tests/unit/audit/test_unit_audit_tasks.py
index fe500fcee6f7..e2d5d11944df 100644
--- a/api/tests/unit/audit/test_unit_audit_tasks.py
+++ b/api/tests/unit/audit/test_unit_audit_tasks.py
@@ -178,7 +178,7 @@ def test_create_segment_priorities_changed_audit_log(
     create_segment_priorities_changed_audit_log(
         previous_id_priority_pairs=[
             (feature_segment.id, 0),
-            (another_feature_segment, 1),
+            (another_feature_segment.id, 1),
         ],
         feature_segment_ids=[feature_segment.id, another_feature_segment.id],
         user_id=admin_user.id,
diff --git a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py
index 24ef642bdd4c..8137e32cf2de 100644
--- a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py
+++ b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py
@@ -3,13 +3,16 @@
 from datetime import timedelta
 
 import pytest
+from pytest_django.fixtures import SettingsWrapper
 
 from task_processor.decorators import (
     register_recurring_task,
     register_task_handler,
 )
+from task_processor.exceptions import InvalidArgumentsError
 from task_processor.models import RecurringTask
 from task_processor.task_registry import get_task
+from task_processor.task_run_method import TaskRunMethod
 
 
 def test_register_task_handler_run_in_thread(mocker, caplog):
@@ -41,7 +44,7 @@ def my_function(*args, **kwargs):
 
     # Then
     mock_thread_class.assert_called_once_with(
-        target=my_function, args=args, kwargs=kwargs, daemon=True
+        target=my_function.unwrapped, args=args, kwargs=kwargs, daemon=True
     )
     mock_thread.start.assert_called_once()
 
@@ -93,3 +96,38 @@ def some_function(first_arg, second_arg):
     assert not RecurringTask.objects.filter(task_identifier=task_identifier).exists()
     with pytest.raises(KeyError):
         assert get_task(task_identifier)
+
+
+def test_register_task_handler_validates_inputs() -> None:
+    # Given
+    @register_task_handler()
+    def my_function(*args, **kwargs):
+        pass
+
+    class NonSerializableObj:
+        pass
+
+    # When
+    with pytest.raises(InvalidArgumentsError):
+        my_function(NonSerializableObj())
+
+
+@pytest.mark.parametrize(
+    "task_run_method", (TaskRunMethod.SEPARATE_THREAD, TaskRunMethod.SYNCHRONOUSLY)
+)
+def test_inputs_are_validated_when_run_without_task_processor(
+    settings: SettingsWrapper, task_run_method: TaskRunMethod
+) -> None:
+    # Given
+    settings.TASK_RUN_METHOD = task_run_method
+
+    @register_task_handler()
+    def my_function(*args, **kwargs):
+        pass
+
+    class NonSerializableObj:
+        pass
+
+    # When
+    with pytest.raises(InvalidArgumentsError):
+        my_function.delay(args=(NonSerializableObj(),))