From 6830ef666c7f9931a4d4edfeef9e58e7d2768dcc Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Tue, 31 Oct 2023 15:50:13 +0530 Subject: [PATCH] feat(task-processor): Add priority support (#2847) --- api/audit/tasks.py | 7 +- api/core/migration_helpers.py | 15 +++- .../identities/edge_request_forwarder.py | 7 +- api/edge_api/identities/tasks.py | 3 +- api/environments/tasks.py | 5 +- api/task_processor/decorators.py | 13 ++-- .../0008_add_get_task_to_process_function.py | 4 +- .../migrations/0010_task_priority.py | 18 +++++ ...11_add_priority_to_get_tasks_to_process.py | 27 ++++++++ ...> 0008_get_recurring_tasks_to_process.sql} | 0 ...cess.sql => 0008_get_tasks_to_process.sql} | 0 .../sql/0011_get_tasks_to_process.sql | 30 ++++++++ api/task_processor/models.py | 55 ++++++++------- .../core/test_unit_core_migration_helpers.py | 68 ++++++++++++++++++- .../test_unit_task_processor_decorators.py | 42 +++++++++++- .../test_unit_task_processor_models.py | 47 ++----------- .../test_unit_task_processor_processor.py | 51 +++++++++++--- 17 files changed, 294 insertions(+), 98 deletions(-) create mode 100644 api/task_processor/migrations/0010_task_priority.py create mode 100644 api/task_processor/migrations/0011_add_priority_to_get_tasks_to_process.py rename api/task_processor/migrations/sql/{get_recurring_tasks_to_process.sql => 0008_get_recurring_tasks_to_process.sql} (100%) rename api/task_processor/migrations/sql/{get_tasks_to_process.sql => 0008_get_tasks_to_process.sql} (100%) create mode 100644 api/task_processor/migrations/sql/0011_get_tasks_to_process.sql diff --git a/api/audit/tasks.py b/api/audit/tasks.py index 2418071ec91c..0b67e6c9ea27 100644 --- a/api/audit/tasks.py +++ b/api/audit/tasks.py @@ -9,18 +9,19 @@ ) from audit.models import AuditLog, RelatedObjectType from task_processor.decorators import register_task_handler +from task_processor.models import TaskPriority logger = logging.getLogger(__name__) -@register_task_handler() +@register_task_handler(priority=TaskPriority.HIGHEST) def create_feature_state_went_live_audit_log(feature_state_id: int): _create_feature_state_audit_log_for_change_request( feature_state_id, FEATURE_STATE_WENT_LIVE_MESSAGE ) -@register_task_handler() +@register_task_handler(priority=TaskPriority.HIGHEST) def create_feature_state_updated_by_change_request_audit_log(feature_state_id: int): _create_feature_state_audit_log_for_change_request( feature_state_id, FEATURE_STATE_UPDATED_BY_CHANGE_REQUEST_MESSAGE @@ -57,7 +58,7 @@ def _create_feature_state_audit_log_for_change_request( ) -@register_task_handler() +@register_task_handler(priority=TaskPriority.HIGHEST) def create_audit_log_from_historical_record( history_instance_id: int, history_user_id: typing.Optional[int], diff --git a/api/core/migration_helpers.py b/api/core/migration_helpers.py index 213065b5a22a..8e773f252695 100644 --- a/api/core/migration_helpers.py +++ b/api/core/migration_helpers.py @@ -1,5 +1,7 @@ +import os import typing import uuid +from contextlib import suppress from django.db import migrations @@ -8,9 +10,16 @@ class PostgresOnlyRunSQL(migrations.RunSQL): @classmethod - def from_sql_file(cls, file_path: str, reverse_sql: str) -> "PostgresOnlyRunSQL": - with open(file_path) as f: - return cls(f.read(), reverse_sql=reverse_sql) + def from_sql_file( + cls, + file_path: typing.Union[str, os.PathLike], + reverse_sql: typing.Union[str, os.PathLike] = None, + ) -> "PostgresOnlyRunSQL": + with open(file_path) as forward_sql: + with suppress(FileNotFoundError, TypeError): + with 
open(reverse_sql) as reverse_sql_file: + reverse_sql = reverse_sql_file.read() + return cls(forward_sql.read(), reverse_sql=reverse_sql) def database_forwards(self, app_label, schema_editor, from_state, to_state): if schema_editor.connection.vendor != "postgresql": diff --git a/api/edge_api/identities/edge_request_forwarder.py b/api/edge_api/identities/edge_request_forwarder.py index 42e7f64696b6..a506a489a480 100644 --- a/api/edge_api/identities/edge_request_forwarder.py +++ b/api/edge_api/identities/edge_request_forwarder.py @@ -7,6 +7,7 @@ from environments.dynamodb.migrator import IdentityMigrator from task_processor.decorators import register_task_handler +from task_processor.models import TaskPriority def _should_forward(project_id: int) -> bool: @@ -14,7 +15,7 @@ def _should_forward(project_id: int) -> bool: return bool(migrator.is_migration_done) -@register_task_handler(queue_size=2000) +@register_task_handler(queue_size=2000, priority=TaskPriority.LOW) def forward_identity_request( request_method: str, headers: dict, @@ -35,7 +36,7 @@ def forward_identity_request( requests.get(url, params=query_params, headers=headers, timeout=5) -@register_task_handler(queue_size=2000) +@register_task_handler(queue_size=2000, priority=TaskPriority.LOW) def forward_trait_request( request_method: str, headers: dict, @@ -61,7 +62,7 @@ def forward_trait_request_sync( ) -@register_task_handler(queue_size=1000) +@register_task_handler(queue_size=1000, priority=TaskPriority.LOW) def forward_trait_requests( request_method: str, headers: str, diff --git a/api/edge_api/identities/tasks.py b/api/edge_api/identities/tasks.py index dc41da83f406..918352e79e41 100644 --- a/api/edge_api/identities/tasks.py +++ b/api/edge_api/identities/tasks.py @@ -6,6 +6,7 @@ from environments.models import Environment, Webhook from features.models import Feature, FeatureState from task_processor.decorators import register_task_handler +from task_processor.models import TaskPriority from users.models import FFAdminUser from webhooks.webhooks import WebhookEventType, call_environment_webhooks @@ -71,7 +72,7 @@ def call_environment_webhook_for_feature_state_change( call_environment_webhooks(environment, data, event_type=event_type) -@register_task_handler() +@register_task_handler(priority=TaskPriority.HIGH) def sync_identity_document_features(identity_uuid: str): from .models import EdgeIdentity diff --git a/api/environments/tasks.py b/api/environments/tasks.py index 0f2daa3d91ce..6197b74ed3db 100644 --- a/api/environments/tasks.py +++ b/api/environments/tasks.py @@ -6,9 +6,10 @@ send_environment_update_message_for_project, ) from task_processor.decorators import register_task_handler +from task_processor.models import TaskPriority -@register_task_handler() +@register_task_handler(priority=TaskPriority.HIGH) def rebuild_environment_document(environment_id: int): wrapper = DynamoEnvironmentWrapper() if wrapper.is_enabled: @@ -16,7 +17,7 @@ def rebuild_environment_document(environment_id: int): wrapper.write_environment(environment) -@register_task_handler() +@register_task_handler(priority=TaskPriority.HIGHEST) def process_environment_update(audit_log_id: int): audit_log = AuditLog.objects.get(id=audit_log_id) diff --git a/api/task_processor/decorators.py b/api/task_processor/decorators.py index a942757960b6..2f7e5482415d 100644 --- a/api/task_processor/decorators.py +++ b/api/task_processor/decorators.py @@ -9,14 +9,18 @@ from django.utils import timezone from task_processor.exceptions import InvalidArgumentsError, 
TaskQueueFullError -from task_processor.models import RecurringTask, Task +from task_processor.models import RecurringTask, Task, TaskPriority from task_processor.task_registry import register_task from task_processor.task_run_method import TaskRunMethod logger = logging.getLogger(__name__) -def register_task_handler(task_name: str = None, queue_size: int = None): +def register_task_handler( + task_name: str = None, + queue_size: int = None, + priority: TaskPriority = TaskPriority.NORMAL, +): def decorator(f: typing.Callable): nonlocal task_name @@ -50,9 +54,10 @@ def delay( else: logger.debug("Creating task for function '%s'...", task_identifier) try: - task = Task.schedule_task( - schedule_for=delay_until or timezone.now(), + task = Task.create( task_identifier=task_identifier, + scheduled_for=delay_until or timezone.now(), + priority=priority, queue_size=queue_size, args=args, kwargs=kwargs, diff --git a/api/task_processor/migrations/0008_add_get_task_to_process_function.py b/api/task_processor/migrations/0008_add_get_task_to_process_function.py index a30425076697..49d047af124c 100644 --- a/api/task_processor/migrations/0008_add_get_task_to_process_function.py +++ b/api/task_processor/migrations/0008_add_get_task_to_process_function.py @@ -16,7 +16,7 @@ class Migration(migrations.Migration): os.path.join( os.path.dirname(__file__), "sql", - "get_tasks_to_process.sql", + "0008_get_tasks_to_process.sql", ), reverse_sql="DROP FUNCTION IF EXISTS get_tasks_to_process", ), @@ -24,7 +24,7 @@ class Migration(migrations.Migration): os.path.join( os.path.dirname(__file__), "sql", - "get_recurring_tasks_to_process.sql", + "0008_get_recurring_tasks_to_process.sql", ), reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process", ), diff --git a/api/task_processor/migrations/0010_task_priority.py b/api/task_processor/migrations/0010_task_priority.py new file mode 100644 index 000000000000..e7fb473b5284 --- /dev/null +++ b/api/task_processor/migrations/0010_task_priority.py @@ -0,0 +1,18 @@ +# Generated by Django 3.2.20 on 2023-10-13 06:04 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('task_processor', '0009_add_recurring_task_run_first_run_at'), + ] + + operations = [ + migrations.AddField( + model_name='task', + name='priority', + field=models.PositiveSmallIntegerField(choices=[(100, 'Lower'), (75, 'Low'), (50, 'Normal'), (25, 'High'), (0, 'Highest')], default=None, null=True), + ), + ] diff --git a/api/task_processor/migrations/0011_add_priority_to_get_tasks_to_process.py b/api/task_processor/migrations/0011_add_priority_to_get_tasks_to_process.py new file mode 100644 index 000000000000..48f2ed8f6703 --- /dev/null +++ b/api/task_processor/migrations/0011_add_priority_to_get_tasks_to_process.py @@ -0,0 +1,27 @@ +# Generated by Django 3.2.20 on 2023-10-13 04:44 + +from django.db import migrations + +from core.migration_helpers import PostgresOnlyRunSQL +import os + + +class Migration(migrations.Migration): + dependencies = [ + ("task_processor", "0010_task_priority"), + ] + + operations = [ + PostgresOnlyRunSQL.from_sql_file( + os.path.join( + os.path.dirname(__file__), + "sql", + "0011_get_tasks_to_process.sql", + ), + reverse_sql=os.path.join( + os.path.dirname(__file__), + "sql", + "0008_get_tasks_to_process.sql", + ), + ), + ] diff --git a/api/task_processor/migrations/sql/get_recurring_tasks_to_process.sql b/api/task_processor/migrations/sql/0008_get_recurring_tasks_to_process.sql similarity index 100% rename from 
api/task_processor/migrations/sql/get_recurring_tasks_to_process.sql
rename to api/task_processor/migrations/sql/0008_get_recurring_tasks_to_process.sql
diff --git a/api/task_processor/migrations/sql/get_tasks_to_process.sql b/api/task_processor/migrations/sql/0008_get_tasks_to_process.sql
similarity index 100%
rename from api/task_processor/migrations/sql/get_tasks_to_process.sql
rename to api/task_processor/migrations/sql/0008_get_tasks_to_process.sql
diff --git a/api/task_processor/migrations/sql/0011_get_tasks_to_process.sql b/api/task_processor/migrations/sql/0011_get_tasks_to_process.sql
new file mode 100644
index 000000000000..2dc6d60a3673
--- /dev/null
+++ b/api/task_processor/migrations/sql/0011_get_tasks_to_process.sql
@@ -0,0 +1,30 @@
+CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer)
+RETURNS SETOF task_processor_task AS $$
+DECLARE
+    row_to_return task_processor_task;
+BEGIN
+    -- Select the tasks that need to be processed
+    FOR row_to_return IN
+        SELECT *
+        FROM task_processor_task
+        WHERE num_failures < 3 AND scheduled_for < NOW() AND completed = FALSE AND is_locked = FALSE
+        ORDER BY priority ASC, scheduled_for ASC, created_at ASC
+        LIMIT num_tasks
+        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
+        FOR UPDATE SKIP LOCKED
+    LOOP
+        -- Lock every selected task (by updating `is_locked` to true)
+        UPDATE task_processor_task
+        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
+        -- transaction is complete (but the tasks are still being executed by the current worker)
+        SET is_locked = TRUE
+        WHERE id = row_to_return.id;
+        -- If we don't explicitly update the `is_locked` column here, the client will receive the row that is actually locked but has the `is_locked` value set to `False`.
+        row_to_return.is_locked := TRUE;
+        RETURN NEXT row_to_return;
+    END LOOP;
+
+    RETURN;
+END;
+$$ LANGUAGE plpgsql
+
diff --git a/api/task_processor/models.py b/api/task_processor/models.py
index 3ef982c806b5..865954dcd1f7 100644
--- a/api/task_processor/models.py
+++ b/api/task_processor/models.py
@@ -11,6 +11,14 @@
 from task_processor.task_registry import registered_tasks
 
 
+class TaskPriority(models.IntegerChoices):
+    LOWER = 100
+    LOW = 75
+    NORMAL = 50
+    HIGH = 25
+    HIGHEST = 0
+
+
 class AbstractBaseTask(models.Model):
     uuid = models.UUIDField(unique=True, default=uuid.uuid4)
     created_at = models.DateTimeField(auto_now_add=True)
@@ -74,6 +82,9 @@ class Task(AbstractBaseTask):
     num_failures = models.IntegerField(default=0)
     completed = models.BooleanField(default=False)
     objects = TaskManager()
+    priority = models.PositiveSmallIntegerField(
+        default=None, null=True, choices=TaskPriority.choices
+    )
 
     class Meta:
         # We have customised the migration in 0004 to only apply this change to postgres databases
@@ -90,44 +101,36 @@ class Meta:
     def create(
         cls,
         task_identifier: str,
+        scheduled_for: datetime,
+        priority: TaskPriority = TaskPriority.NORMAL,
+        queue_size: int = None,
         *,
         args: typing.Tuple[typing.Any] = None,
         kwargs: typing.Dict[str, typing.Any] = None,
     ) -> "Task":
+        if queue_size and cls._is_queue_full(task_identifier, queue_size):
+            raise TaskQueueFullError(
+                f"Queue for task {task_identifier} is full. 
" + f"Max queue size is {queue_size}" + ) return Task( task_identifier=task_identifier, + scheduled_for=scheduled_for, + priority=priority, serialized_args=cls.serialize_data(args or tuple()), serialized_kwargs=cls.serialize_data(kwargs or dict()), ) @classmethod - def schedule_task( - cls, - schedule_for: datetime, - task_identifier: str, - queue_size: typing.Optional[int], - *, - args: typing.Tuple[typing.Any] = None, - kwargs: typing.Dict[str, typing.Any] = None, - ) -> "Task": - if queue_size: - if ( - cls.objects.filter( - task_identifier=task_identifier, completed=False, num_failures__lt=3 - ).count() - > queue_size - ): - raise TaskQueueFullError( - f"Queue for task {task_identifier} is full. " - f"Max queue size is {queue_size}" - ) - task = cls.create( - task_identifier=task_identifier, - args=args, - kwargs=kwargs, + def _is_queue_full(cls, task_identifier: str, queue_size: int) -> bool: + return ( + cls.objects.filter( + task_identifier=task_identifier, + completed=False, + num_failures__lt=3, + ).count() + > queue_size ) - task.scheduled_for = schedule_for - return task def mark_failure(self): super().mark_failure() diff --git a/api/tests/unit/core/test_unit_core_migration_helpers.py b/api/tests/unit/core/test_unit_core_migration_helpers.py index 4ec3ff5dc45f..ce52062d82e2 100644 --- a/api/tests/unit/core/test_unit_core_migration_helpers.py +++ b/api/tests/unit/core/test_unit_core_migration_helpers.py @@ -1,4 +1,4 @@ -from core.migration_helpers import AddDefaultUUIDs +from core.migration_helpers import AddDefaultUUIDs, PostgresOnlyRunSQL def test_add_default_uuids_class_correctly_sets_uuid_attribute(mocker): @@ -27,3 +27,69 @@ def test_add_default_uuids_class_correctly_sets_uuid_attribute(mocker): mock_model_class.objects.bulk_update.assert_called_once_with( [mock_model_object], fields=["uuid"] ) + + +def test_postgres_only_run_sql__from_sql_file__with_reverse_sql_as_string( + mocker, tmp_path +): + # Given + forward_sql = "SELECT 1;" + reverse_sql = "SELECT 2;" + + # Create a temporary file + sql_file = tmp_path / "forward_test.sql" + sql_file.write_text(forward_sql) + + mocked_init = mocker.patch( + "core.migration_helpers.PostgresOnlyRunSQL.__init__", return_value=None + ) + + # When + PostgresOnlyRunSQL.from_sql_file(sql_file, reverse_sql) + + # Then + mocked_init.assert_called_once_with(forward_sql, reverse_sql=reverse_sql) + + +def test_postgres_only_run_sql__from_sql_file__with_reverse_sql_as_file_path( + mocker, tmp_path +): + # Given + forward_sql = "SELECT 1;" + reverse_sql = "SELECT 2;" + + # Create temporary files + forward_sql_file = tmp_path / "forward_test.sql" + forward_sql_file.write_text(forward_sql) + + reverse_sql_file = tmp_path / "reverse_test.sql" + reverse_sql_file.write_text(reverse_sql) + + mocked_init = mocker.patch( + "core.migration_helpers.PostgresOnlyRunSQL.__init__", return_value=None + ) + + # When + PostgresOnlyRunSQL.from_sql_file(forward_sql_file, reverse_sql_file) + + # Then + mocked_init.assert_called_once_with(forward_sql, reverse_sql=reverse_sql) + + +def test_postgres_only_run_sql__from_sql_file__without_reverse_sql(mocker, tmp_path): + # Given + forward_sql = "SELECT 1;" + + # Create temporary files + forward_sql_file = tmp_path / "forward_test.sql" + forward_sql_file.write_text(forward_sql) + + mocked_init = mocker.patch( + "core.migration_helpers.PostgresOnlyRunSQL.__init__", return_value=None + ) + + # When + PostgresOnlyRunSQL.from_sql_file(forward_sql_file) + + # Then + mocked_init.assert_called_once_with(forward_sql, 
reverse_sql=None) diff --git a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py index 8137e32cf2de..4b730ff8f176 100644 --- a/api/tests/unit/task_processor/test_unit_task_processor_decorators.py +++ b/api/tests/unit/task_processor/test_unit_task_processor_decorators.py @@ -10,7 +10,7 @@ register_task_handler, ) from task_processor.exceptions import InvalidArgumentsError -from task_processor.models import RecurringTask +from task_processor.models import RecurringTask, Task, TaskPriority from task_processor.task_registry import get_task from task_processor.task_run_method import TaskRunMethod @@ -131,3 +131,43 @@ class NonSerializableObj: # When with pytest.raises(InvalidArgumentsError): my_function.delay(args=(NonSerializableObj(),)) + + +def test_delay_returns_none_if_task_queue_is_full(settings, db): + # Given + settings.TASK_RUN_METHOD = TaskRunMethod.TASK_PROCESSOR + + @register_task_handler(queue_size=1) + def my_function(*args, **kwargs): + pass + + for _ in range(10): + Task.objects.create( + task_identifier="test_unit_task_processor_decorators.my_function" + ) + + # When + task = my_function.delay() + + # Then + assert task is None + + +def test_can_create_task_with_priority(settings, db): + # Given + settings.TASK_RUN_METHOD = TaskRunMethod.TASK_PROCESSOR + + @register_task_handler(priority=TaskPriority.HIGH) + def my_function(*args, **kwargs): + pass + + for _ in range(10): + Task.objects.create( + task_identifier="test_unit_task_processor_decorators.my_function" + ) + + # When + task = my_function.delay() + + # Then + assert task.priority == TaskPriority.HIGH diff --git a/api/tests/unit/task_processor/test_unit_task_processor_models.py b/api/tests/unit/task_processor/test_unit_task_processor_models.py index 989a4a3573cd..4b4b02018741 100644 --- a/api/tests/unit/task_processor/test_unit_task_processor_models.py +++ b/api/tests/unit/task_processor/test_unit_task_processor_models.py @@ -5,7 +5,6 @@ from django.utils import timezone from task_processor.decorators import register_task_handler -from task_processor.exceptions import TaskQueueFullError from task_processor.models import RecurringTask, Task now = timezone.now() @@ -24,7 +23,12 @@ def test_task_run(): args = ["foo"] kwargs = {"arg_two": "bar"} - task = Task.create(my_callable.task_identifier, args=args, kwargs=kwargs) + task = Task.create( + my_callable.task_identifier, + scheduled_for=timezone.now(), + args=args, + kwargs=kwargs, + ) # When result = task.run() @@ -55,42 +59,3 @@ def test_recurring_task_run_should_execute_first_run_at(first_run_time, expected ).should_execute == expected ) - - -def test_schedule_task_raises_error_if_queue_is_full(db): - # Given - task_identifier = "my_callable" - - # some incomplete task - for _ in range(10): - Task.objects.create(task_identifier=task_identifier) - - # When - with pytest.raises(TaskQueueFullError): - Task.schedule_task( - schedule_for=timezone.now(), task_identifier=task_identifier, queue_size=9 - ) - - -def test_can_schedule_task_raises_error_if_queue_is_not_full(db): - # Given - task_identifier = "my_callable" - - # Some incomplete task - for _ in range(10): - Task.objects.create(task_identifier=task_identifier) - - # tasks with different identifiers - Task.objects.create(task_identifier="task_with_different_identifier") - - # failed tasks - Task.objects.create( - task_identifier="task_with_different_identifier", num_failures=3 - ) - - # When - task = Task.schedule_task( - 
schedule_for=timezone.now(), task_identifier=task_identifier, queue_size=10 - ) - # Then - assert task is not None diff --git a/api/tests/unit/task_processor/test_unit_task_processor_processor.py b/api/tests/unit/task_processor/test_unit_task_processor_processor.py index 154daedc4bbc..e954956f4651 100644 --- a/api/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/api/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -15,6 +15,7 @@ RecurringTask, RecurringTaskRun, Task, + TaskPriority, TaskResult, TaskRun, ) @@ -25,7 +26,11 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(db): # Given organisation_name = f"test-org-{uuid.uuid4()}" - task = Task.create(_create_organisation.task_identifier, args=(organisation_name,)) + task = Task.create( + _create_organisation.task_identifier, + scheduled_for=timezone.now(), + args=(organisation_name,), + ) task.save() # When @@ -168,7 +173,7 @@ def _a_task(): def test_run_task_runs_task_and_creates_task_run_object_when_failure(db): # Given - task = Task.create(_raise_exception.task_identifier) + task = Task.create(_raise_exception.task_identifier, scheduled_for=timezone.now()) task.save() # When @@ -188,7 +193,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure(db): def test_run_task_runs_failed_task_again(db): # Given - task = Task.create(_raise_exception.task_identifier) + task = Task.create(_raise_exception.task_identifier, scheduled_for=timezone.now()) task.save() # When @@ -248,26 +253,42 @@ def test_run_task_does_nothing_if_no_tasks(db): @pytest.mark.django_db(transaction=True) -def test_run_task_runs_tasks_in_correct_order(): +def test_run_task_runs_tasks_in_correct_priority(): # Given # 2 tasks task_1 = Task.create( - _create_organisation.task_identifier, args=("task 1 organisation",) + _create_organisation.task_identifier, + scheduled_for=timezone.now(), + args=("task 1 organisation",), + priority=TaskPriority.HIGH, ) task_1.save() task_2 = Task.create( - _create_organisation.task_identifier, args=("task 2 organisation",) + _create_organisation.task_identifier, + scheduled_for=timezone.now(), + args=("task 2 organisation",), + priority=TaskPriority.HIGH, ) task_2.save() + task_3 = Task.create( + _create_organisation.task_identifier, + scheduled_for=timezone.now(), + args=("task 3 organisation",), + priority=TaskPriority.HIGHEST, + ) + task_3.save() + # When task_runs_1 = run_tasks() task_runs_2 = run_tasks() + task_runs_3 = run_tasks() # Then - assert task_runs_1[0].task == task_1 - assert task_runs_2[0].task == task_2 + assert task_runs_1[0].task == task_3 + assert task_runs_2[0].task == task_1 + assert task_runs_3[0].task == task_2 @pytest.mark.django_db(transaction=True) @@ -280,12 +301,16 @@ def test_run_tasks_skips_locked_tasks(): # 2 tasks # One which is configured to just sleep for 3 seconds, to simulate a task # being held for a short period of time - task_1 = Task.create(_sleep.task_identifier, args=(3,)) + task_1 = Task.create( + _sleep.task_identifier, scheduled_for=timezone.now(), args=(3,) + ) task_1.save() # and another which should create an organisation task_2 = Task.create( - _create_organisation.task_identifier, args=("task 2 organisation",) + _create_organisation.task_identifier, + scheduled_for=timezone.now(), + args=("task 2 organisation",), ) task_2.save() @@ -313,7 +338,11 @@ def test_run_more_than_one_task(db): for _ in range(num_tasks): organisation_name = f"test-org-{uuid.uuid4()}" tasks.append( - 
Task.create(_create_organisation.task_identifier, args=(organisation_name,)) + Task.create( + _create_organisation.task_identifier, + scheduled_for=timezone.now(), + args=(organisation_name,), + ) ) Task.objects.bulk_create(tasks)
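
Usage sketch (illustrative note, not part of the patch): the handler below is
hypothetical, but register_task_handler, TaskPriority, delay(), and queue_size come
from the changes above. Lower numeric values win: get_tasks_to_process() orders by
priority ASC, so HIGHEST (0) is dequeued before NORMAL (50) and LOWER (100).

    from task_processor.decorators import register_task_handler
    from task_processor.models import TaskPriority

    # Priority is fixed at registration time (default TaskPriority.NORMAL) and is
    # stamped onto every Task row that delay() creates for this handler.
    @register_task_handler(queue_size=500, priority=TaskPriority.HIGH)
    def refresh_environment_cache(environment_id: int):  # hypothetical handler
        ...

    # Enqueue as before; this task is picked up ahead of NORMAL-priority work.
    task = refresh_environment_cache.delay(args=(42,))

    # Per test_delay_returns_none_if_task_queue_is_full, when queue_size is set and
    # the queue is full, delay() returns None rather than raising TaskQueueFullError.
    if task is None:
        ...  # queue was full; the enqueue was dropped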