Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(task-processor): Add priority support #2847

Merged
merged 16 commits into from
Oct 31, 2023
Merged
7 changes: 4 additions & 3 deletions api/audit/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@
)
from audit.models import AuditLog, RelatedObjectType
from task_processor.decorators import register_task_handler
from task_processor.models import TaskPriority

logger = logging.getLogger(__name__)


@register_task_handler()
@register_task_handler(priority=TaskPriority.HIGHEST)
def create_feature_state_went_live_audit_log(feature_state_id: int):
    """Create the audit log entry for a feature state going live via a change request.

    Registered at HIGHEST priority so the audit trail reflects go-live
    events ahead of lower-priority background work.
    """
    _create_feature_state_audit_log_for_change_request(
        feature_state_id, FEATURE_STATE_WENT_LIVE_MESSAGE
    )


@register_task_handler()
@register_task_handler(priority=TaskPriority.HIGHEST)
def create_feature_state_updated_by_change_request_audit_log(feature_state_id: int):
_create_feature_state_audit_log_for_change_request(
feature_state_id, FEATURE_STATE_UPDATED_BY_CHANGE_REQUEST_MESSAGE
Expand Down Expand Up @@ -57,7 +58,7 @@ def _create_feature_state_audit_log_for_change_request(
)


@register_task_handler()
@register_task_handler(priority=TaskPriority.HIGHEST)
def create_audit_log_from_historical_record(
history_instance_id: int,
history_user_id: typing.Optional[int],
Expand Down
15 changes: 12 additions & 3 deletions api/core/migration_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import typing
import uuid
from contextlib import suppress

from django.db import migrations

Expand All @@ -8,9 +10,16 @@

class PostgresOnlyRunSQL(migrations.RunSQL):
@classmethod
def from_sql_file(
    cls,
    file_path: typing.Union[str, os.PathLike],
    reverse_sql: typing.Union[str, os.PathLike, None] = None,
) -> "PostgresOnlyRunSQL":
    """Build a RunSQL operation whose forward SQL is read from a file.

    :param file_path: path to the file containing the forward SQL.
    :param reverse_sql: ``None`` (irreversible), a raw SQL string, or a
        path to a file containing the reverse SQL.
    """
    with open(file_path) as forward_sql:
        # `reverse_sql` may be None, raw SQL, or a file path. Try to read
        # it as a file; on any failure fall back to using it verbatim.
        # Original code suppressed only FileNotFoundError/TypeError, so a
        # long raw SQL string raised OSError (ENAMETOOLONG) and one with
        # an embedded NUL raised ValueError from open() — suppress those
        # too (FileNotFoundError is an OSError subclass).
        with suppress(TypeError, ValueError, OSError):
            with open(reverse_sql) as reverse_sql_file:
                reverse_sql = reverse_sql_file.read()
        return cls(forward_sql.read(), reverse_sql=reverse_sql)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
if schema_editor.connection.vendor != "postgresql":
Expand Down
7 changes: 4 additions & 3 deletions api/edge_api/identities/edge_request_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

from environments.dynamodb.migrator import IdentityMigrator
from task_processor.decorators import register_task_handler
from task_processor.models import TaskPriority


def _should_forward(project_id: int) -> bool:
    """Forward requests only once the project's identity migration is done."""
    return bool(IdentityMigrator(project_id).is_migration_done)


@register_task_handler(queue_size=2000)
@register_task_handler(queue_size=2000, priority=TaskPriority.LOW)
def forward_identity_request(
request_method: str,
headers: dict,
Expand All @@ -35,7 +36,7 @@ def forward_identity_request(
requests.get(url, params=query_params, headers=headers, timeout=5)


@register_task_handler(queue_size=2000)
@register_task_handler(queue_size=2000, priority=TaskPriority.LOW)
def forward_trait_request(
request_method: str,
headers: dict,
Expand All @@ -61,7 +62,7 @@ def forward_trait_request_sync(
)


@register_task_handler(queue_size=1000)
@register_task_handler(queue_size=1000, priority=TaskPriority.LOW)
def forward_trait_requests(
request_method: str,
headers: str,
Expand Down
3 changes: 2 additions & 1 deletion api/edge_api/identities/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from environments.models import Environment, Webhook
from features.models import Feature, FeatureState
from task_processor.decorators import register_task_handler
from task_processor.models import TaskPriority
from users.models import FFAdminUser
from webhooks.webhooks import WebhookEventType, call_environment_webhooks

Expand Down Expand Up @@ -71,7 +72,7 @@ def call_environment_webhook_for_feature_state_change(
call_environment_webhooks(environment, data, event_type=event_type)


@register_task_handler()
@register_task_handler(priority=TaskPriority.HIGH)
def sync_identity_document_features(identity_uuid: str):
from .models import EdgeIdentity

Expand Down
5 changes: 3 additions & 2 deletions api/environments/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@
send_environment_update_message_for_project,
)
from task_processor.decorators import register_task_handler
from task_processor.models import TaskPriority


@register_task_handler(priority=TaskPriority.HIGH)
def rebuild_environment_document(environment_id: int):
    """Rewrite the environment's document via the Dynamo wrapper.

    No-op when the wrapper is not enabled.
    """
    wrapper = DynamoEnvironmentWrapper()
    if not wrapper.is_enabled:
        return
    wrapper.write_environment(Environment.objects.get(id=environment_id))


@register_task_handler()
@register_task_handler(priority=TaskPriority.HIGHEST)
def process_environment_update(audit_log_id: int):
audit_log = AuditLog.objects.get(id=audit_log_id)

Expand Down
13 changes: 9 additions & 4 deletions api/task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@
from django.utils import timezone

from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError
from task_processor.models import RecurringTask, Task
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import register_task
from task_processor.task_run_method import TaskRunMethod

logger = logging.getLogger(__name__)


def register_task_handler(task_name: str = None, queue_size: int = None):
def register_task_handler(
task_name: str = None,
queue_size: int = None,
priority: TaskPriority = TaskPriority.NORMAL,
):
def decorator(f: typing.Callable):
nonlocal task_name

Expand Down Expand Up @@ -50,9 +54,10 @@ def delay(
else:
logger.debug("Creating task for function '%s'...", task_identifier)
try:
task = Task.schedule_task(
schedule_for=delay_until or timezone.now(),
task = Task.create(
task_identifier=task_identifier,
scheduled_for=delay_until or timezone.now(),
priority=priority,
queue_size=queue_size,
args=args,
kwargs=kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class Migration(migrations.Migration):
os.path.join(
os.path.dirname(__file__),
"sql",
"get_tasks_to_process.sql",
"0008_get_tasks_to_process.sql",
),
reverse_sql="DROP FUNCTION IF EXISTS get_tasks_to_process",
),
PostgresOnlyRunSQL.from_sql_file(
os.path.join(
os.path.dirname(__file__),
"sql",
"get_recurring_tasks_to_process.sql",
"0008_get_recurring_tasks_to_process.sql",
),
reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process",
),
Expand Down
18 changes: 18 additions & 0 deletions api/task_processor/migrations/0010_task_priority.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 3.2.20 on 2023-10-13 06:04

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the nullable `priority` column to the Task model."""

    dependencies = [
        ('task_processor', '0009_add_recurring_task_run_first_run_at'),
    ]

    operations = [
        # Nullable with default=None: pre-existing rows keep NULL priority
        # (no backfill). Lower numeric value means higher priority
        # (0 = Highest ... 100 = Lower).
        migrations.AddField(
            model_name='task',
            name='priority',
            field=models.PositiveSmallIntegerField(choices=[(100, 'Lower'), (75, 'Low'), (50, 'Normal'), (25, 'High'), (0, 'Highest')], default=None, null=True),
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 3.2.20 on 2023-10-13 04:44

from django.db import migrations

from core.migration_helpers import PostgresOnlyRunSQL
import os


class Migration(migrations.Migration):
    """Replace get_tasks_to_process() with the priority-aware version."""

    dependencies = [
        ("task_processor", "0010_task_priority"),
    ]

    operations = [
        # Forward: install the 0011 function (orders by priority).
        # Reverse: re-install the previous 0008 definition from its file,
        # so this migration remains reversible.
        PostgresOnlyRunSQL.from_sql_file(
            os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0011_get_tasks_to_process.sql",
            ),
            reverse_sql=os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0008_get_tasks_to_process.sql",
            ),
        ),
    ]
30 changes: 30 additions & 0 deletions api/task_processor/migrations/sql/0011_get_tasks_to_process.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-- Atomically claim up to num_tasks runnable tasks: highest priority first
-- (lower numeric value = higher priority), then oldest schedule/creation.
-- NOTE(review): rows with NULL priority (pre-0010 tasks) sort LAST under
-- ORDER BY ... ASC in Postgres -- confirm that is the intended behaviour
-- for legacy tasks.
CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer)
RETURNS SETOF task_processor_task AS $$
DECLARE
    row_to_return task_processor_task;
BEGIN
    -- Select the tasks that need to be processed
    FOR row_to_return IN
        SELECT *
        FROM task_processor_task
        WHERE num_failures < 3 AND scheduled_for < NOW() AND completed = FALSE AND is_locked = FALSE
        ORDER BY priority ASC, scheduled_for ASC, created_at ASC
        LIMIT num_tasks
        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
        FOR UPDATE SKIP LOCKED
    LOOP
        -- Lock every selected task (by updating `is_locked` to true)
        UPDATE task_processor_task
        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
        -- transaction is complete (but the tasks are still being executed by the current worker)
        SET is_locked = TRUE
        WHERE id = row_to_return.id;
        -- If we don't explicitly update the `is_locked` column here, the client will receive the row that is actually locked but has the `is_locked` value set to `False`.
        row_to_return.is_locked := TRUE;
        RETURN NEXT row_to_return;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql

56 changes: 30 additions & 26 deletions api/task_processor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from task_processor.task_registry import registered_tasks


class TaskPriority(models.IntegerChoices):
    """Priority levels for queued tasks.

    Lower numeric value = higher priority: the task-processor SQL fetches
    pending work with ``ORDER BY priority ASC``, so HIGHEST (0) is picked
    before LOWER (100).
    """

    LOWER = 100
    LOW = 75
    NORMAL = 50
    HIGH = 25
    HIGHEST = 0


class AbstractBaseTask(models.Model):
uuid = models.UUIDField(unique=True, default=uuid.uuid4)
created_at = models.DateTimeField(auto_now_add=True)
Expand Down Expand Up @@ -74,6 +82,9 @@ class Task(AbstractBaseTask):
num_failures = models.IntegerField(default=0)
completed = models.BooleanField(default=False)
objects = TaskManager()
priority = models.PositiveSmallIntegerField(
default=None, null=True, choices=TaskPriority.choices
)

class Meta:
# We have customised the migration in 0004 to only apply this change to postgres databases
Expand All @@ -90,44 +101,37 @@ class Meta:
def create(
    cls,
    task_identifier: str,
    scheduled_for: datetime,
    priority: TaskPriority = TaskPriority.NORMAL,
    queue_size: typing.Optional[int] = None,
    *,
    args: typing.Tuple[typing.Any] = None,
    kwargs: typing.Dict[str, typing.Any] = None,
) -> "Task":
    """Build an (unsaved) Task instance for a registered task handler.

    :param task_identifier: identifier the handler was registered under.
    :param scheduled_for: earliest time the task should be picked up.
    :param priority: TaskPriority value; lower number = higher priority.
    :param queue_size: if given, reject the task when more than this many
        runs of the same identifier are already pending.
    :param args: positional arguments, serialized onto the instance.
    :param kwargs: keyword arguments, serialized onto the instance.
    :raises TaskQueueFullError: when the queue_size limit is exceeded.
    """
    # Flattened from the original nested `if`s (reviewer suggestion);
    # a falsy queue_size (None or 0) skips the queue check entirely.
    if queue_size and cls.is_queue_full(task_identifier, queue_size):
        raise TaskQueueFullError(
            f"Queue for task {task_identifier} is full. "
            f"Max queue size is {queue_size}"
        )
    # NOTE: the instance is not persisted here; callers save it.
    return Task(
        task_identifier=task_identifier,
        scheduled_for=scheduled_for,
        priority=priority,
        serialized_args=cls.serialize_data(args or tuple()),
        serialized_kwargs=cls.serialize_data(kwargs or dict()),
    )

@classmethod
def is_queue_full(cls, task_identifier: str, queue_size: int) -> bool:
    """Return True when more than `queue_size` runs of this task are pending.

    A run counts as pending while it is not completed and has failed
    fewer than 3 times (matching the task-processor pickup criteria).
    """
    pending = cls.objects.filter(
        task_identifier=task_identifier,
        completed=False,
        num_failures__lt=3,
    ).count()
    return pending > queue_size

def mark_failure(self):
super().mark_failure()
Expand Down
Loading