Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure recurring tasks are unlocked after being picked up (but not executed) #2508

Merged
merged 3 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions api/task_processor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ def deserialize_data(data: typing.Any):
return json.loads(data)

def mark_failure(self):
self.is_locked = False
self.unlock()

def mark_success(self):
self.unlock()

def unlock(self):
self.is_locked = False

def run(self):
Expand Down Expand Up @@ -119,7 +122,7 @@ def mark_failure(self):
self.num_failures += 1

def mark_success(self):
super().mark_failure()
super().mark_success()
self.completed = True


Expand Down
13 changes: 7 additions & 6 deletions api/task_processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:
return []


def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]:
def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:
if num_tasks < 1:
raise ValueError("Number of tasks to process must be at least one")

Expand All @@ -55,7 +55,6 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]:
tasks = RecurringTask.objects.get_tasks_to_process(num_tasks)
if tasks:
task_runs = []
executed_tasks = []

for task in tasks:
# Remove the task if it's not registered anymore
Expand All @@ -65,11 +64,13 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]:

if task.should_execute:
task, task_run = _run_task(task)
executed_tasks.append(task)
task_runs.append(task_run)
else:
task.unlock()

if executed_tasks:
RecurringTask.objects.bulk_update(executed_tasks, fields=["is_locked"])
# update all tasks that were not deleted
to_update = [task for task in tasks if task.id]
RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])

if task_runs:
RecurringTaskRun.objects.bulk_create(task_runs)
Expand All @@ -80,7 +81,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]:
return []


def _run_task(task: Task) -> typing.Optional[typing.Tuple[Task, TaskRun]]:
def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, TaskRun]:
task_run = task.task_runs.model(started_at=timezone.now(), task=task)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from threading import Thread

import pytest
from django.utils import timezone

from organisations.models import Organisation
from task_processor.decorators import (
Expand Down Expand Up @@ -80,7 +81,7 @@ def test_run_recurring_tasks_multiple_runs(db, run_by_processor):
task_identifier = "test_unit_task_processor_processor._create_organisation"

@register_recurring_task(
run_every=timedelta(milliseconds=100), args=(organisation_name,)
run_every=timedelta(milliseconds=200), args=(organisation_name,)
)
def _create_organisation(organisation_name):
Organisation.objects.create(name=organisation_name)
Expand All @@ -89,17 +90,28 @@ def _create_organisation(organisation_name):

# When
first_task_runs = run_recurring_tasks()
time.sleep(0.2)

second_task_runs = run_recurring_tasks()
# run the process again before the task is scheduled to run again to ensure
# that tasks are unlocked when they are picked up by the task processor but
# not executed.
no_task_runs = run_recurring_tasks()

task_runs = first_task_runs + second_task_runs
time.sleep(0.3)

second_task_runs = run_recurring_tasks()

# Then
assert len(first_task_runs) == 1
assert len(no_task_runs) == 0
assert len(second_task_runs) == 1

# we should still only have 2 organisations, despite executing the
# `run_recurring_tasks` function 3 times.
assert Organisation.objects.filter(name=organisation_name).count() == 2

assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 2
for task_run in task_runs:
all_task_runs = first_task_runs + second_task_runs
assert len(all_task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 2
for task_run in all_task_runs:
assert task_run.result == TaskResult.SUCCESS
assert task_run.started_at
assert task_run.finished_at
Expand Down Expand Up @@ -322,6 +334,37 @@ def test_run_more_than_one_task(db):
assert task.completed


def test_recurring_tasks_are_unlocked_if_picked_up_but_not_executed(
db, run_by_processor
):
# Given
@register_recurring_task(run_every=timedelta(days=1))
def my_task():
pass

recurring_task = RecurringTask.objects.get(
task_identifier="test_unit_task_processor_processor.my_task"
)

# mimic the task having already been run so that it is next picked up,
# but not executed
now = timezone.now()
one_minute_ago = now - timedelta(minutes=1)
RecurringTaskRun.objects.create(
task=recurring_task,
started_at=one_minute_ago,
finished_at=now,
result=TaskResult.SUCCESS.name,
)

# When
run_recurring_tasks()

# Then
recurring_task.refresh_from_db()
assert recurring_task.is_locked is False


@register_task_handler()
def _create_organisation(name: str):
"""function used to test that task is being run successfully"""
Expand Down