Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tasks/queue-size): Implement queue_size #2826

Merged
merged 3 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions api/edge_api/identities/edge_request_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _should_forward(project_id: int) -> bool:
return bool(migrator.is_migration_done)


@register_task_handler()
@register_task_handler(queue_size=2000)
def forward_identity_request(
request_method: str,
headers: dict,
Expand All @@ -35,7 +35,7 @@ def forward_identity_request(
requests.get(url, params=query_params, headers=headers, timeout=5)


@register_task_handler()
@register_task_handler(queue_size=2000)
def forward_trait_request(
request_method: str,
headers: dict,
Expand All @@ -52,7 +52,6 @@ def forward_trait_request_sync(
return

url = settings.EDGE_API_URL + "traits/"
payload = payload
payload = json.dumps(payload)
requests.post(
url,
Expand All @@ -62,7 +61,7 @@ def forward_trait_request_sync(
)


@register_task_handler()
@register_task_handler(queue_size=1000)
def forward_trait_requests(
request_method: str,
headers: str,
Expand Down
22 changes: 14 additions & 8 deletions api/task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from django.conf import settings
from django.utils import timezone

from task_processor.exceptions import InvalidArgumentsError
from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError
from task_processor.models import RecurringTask, Task
from task_processor.task_registry import register_task
from task_processor.task_run_method import TaskRunMethod

logger = logging.getLogger(__name__)


def register_task_handler(task_name: str = None):
def register_task_handler(task_name: str = None, queue_size: int = None):
def decorator(f: typing.Callable):
nonlocal task_name

Expand Down Expand Up @@ -49,12 +49,18 @@ def delay(
run_in_thread(args=args, kwargs=kwargs)
else:
logger.debug("Creating task for function '%s'...", task_identifier)
task = Task.schedule_task(
schedule_for=delay_until or timezone.now(),
task_identifier=task_identifier,
args=args,
kwargs=kwargs,
)
try:
task = Task.schedule_task(
schedule_for=delay_until or timezone.now(),
task_identifier=task_identifier,
queue_size=queue_size,
args=args,
kwargs=kwargs,
)
except TaskQueueFullError as e:
logger.warning(e)
return

task.save()
return task

Expand Down
4 changes: 4 additions & 0 deletions api/task_processor/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ class TaskProcessingError(Exception):

class InvalidArgumentsError(TaskProcessingError):
pass


class TaskQueueFullError(Exception):
pass
14 changes: 13 additions & 1 deletion api/task_processor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from django.db import models
from django.utils import timezone

from task_processor.exceptions import TaskProcessingError
from task_processor.exceptions import TaskProcessingError, TaskQueueFullError
from task_processor.managers import RecurringTaskManager, TaskManager
from task_processor.task_registry import registered_tasks

Expand Down Expand Up @@ -105,10 +105,22 @@ def schedule_task(
cls,
schedule_for: datetime,
task_identifier: str,
queue_size: typing.Optional[int],
*,
args: typing.Tuple[typing.Any] = None,
kwargs: typing.Dict[str, typing.Any] = None,
) -> "Task":
if queue_size:
if (
cls.objects.filter(
task_identifier=task_identifier, completed=False, num_failures__lt=3
).count()
> queue_size
):
raise TaskQueueFullError(
f"Queue for task {task_identifier} is full. "
f"Max queue size is {queue_size}"
)
task = cls.create(
task_identifier=task_identifier,
args=args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils import timezone

from task_processor.decorators import register_task_handler
from task_processor.exceptions import TaskQueueFullError
from task_processor.models import RecurringTask, Task

now = timezone.now()
Expand Down Expand Up @@ -54,3 +55,42 @@ def test_recurring_task_run_should_execute_first_run_at(first_run_time, expected
).should_execute
== expected
)


def test_schedule_task_raises_error_if_queue_is_full(db):
# Given
task_identifier = "my_callable"

# some incomplete task
for _ in range(10):
Task.objects.create(task_identifier=task_identifier)

# When
with pytest.raises(TaskQueueFullError):
Task.schedule_task(
schedule_for=timezone.now(), task_identifier=task_identifier, queue_size=9
)


def test_can_schedule_task_raises_error_if_queue_is_not_full(db):
# Given
task_identifier = "my_callable"

# Some incomplete task
for _ in range(10):
Task.objects.create(task_identifier=task_identifier)

# tasks with different identifiers
Task.objects.create(task_identifier="task_with_different_identifier")

# failed tasks
Task.objects.create(
task_identifier="task_with_different_identifier", num_failures=3
)

# When
task = Task.schedule_task(
schedule_for=timezone.now(), task_identifier=task_identifier, queue_size=10
)
# Then
assert task is not None