Skip to content

Commit

Permalink
fix(17): Add timeout to auto unlock recurring tasks (#16)
Browse files Browse the repository at this point in the history
* fix(17/recurring-task-lock): Add timeout to auto unlock task

* fix locks for non-recurring tasks

* remove auto unlock from tasks
  • Loading branch information
gagantrivedi authored Jan 14, 2025
1 parent f92adfd commit 4722bf1
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 20 deletions.
8 changes: 8 additions & 0 deletions task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]):
"priority",
"transaction_on_commit",
"task_identifier",
"timeout",
)

unwrapped: typing.Callable[P, None]
Expand All @@ -38,11 +39,13 @@ def __init__(
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = None,
) -> None:
self.unwrapped = f
self.queue_size = queue_size
self.priority = priority
self.transaction_on_commit = transaction_on_commit
self.timeout = timeout

task_name = task_name or f.__name__
task_module = getmodule(f).__name__.rsplit(".")[-1]
Expand Down Expand Up @@ -87,6 +90,7 @@ def delay(
scheduled_for=delay_until or timezone.now(),
priority=self.priority,
queue_size=self.queue_size,
timeout=self.timeout,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -124,6 +128,7 @@ def register_task_handler( # noqa: C901
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = timedelta(seconds=60),
) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
"""
Turn a function into an asynchronous task.
Expand All @@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
queue_size=queue_size,
priority=priority,
transaction_on_commit=transaction_on_commit,
timeout=timeout,
)

return wrapper
Expand All @@ -161,6 +167,7 @@ def register_recurring_task(
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
first_run_time: time | None = None,
timeout: timedelta | None = timedelta(minutes=30),
) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
if not os.environ.get("RUN_BY_PROCESSOR"):
# Do not register recurring tasks if not invoked by task processor
Expand All @@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
"timeout": timeout,
},
)
return task
Expand Down
4 changes: 2 additions & 2 deletions task_processor/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def get_tasks_to_process(self, num_tasks: int) -> QuerySet["Task"]:


class RecurringTaskManager(Manager):
def get_tasks_to_process(self, num_tasks: int) -> QuerySet["RecurringTask"]:
return self.raw("SELECT * FROM get_recurringtasks_to_process(%s)", [num_tasks])
def get_tasks_to_process(self) -> QuerySet["RecurringTask"]:
return self.raw("SELECT * FROM get_recurringtasks_to_process()")
39 changes: 39 additions & 0 deletions task_processor/migrations/0012_add_locked_at_and_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Generated by Django 3.2.23 on 2025-01-06 04:51

from task_processor.migrations.helpers import PostgresOnlyRunSQL
import datetime
from django.db import migrations, models
import os


class Migration(migrations.Migration):
    """Add lock-timeout support: `locked_at`/`timeout` on RecurringTask,
    `timeout` on Task, and a reworked Postgres task-selection function."""

    dependencies = [
        ("task_processor", "0011_add_priority_to_get_tasks_to_process"),
    ]

    operations = [
        # When a worker locked the recurring task; NULL while unlocked.
        migrations.AddField(
            model_name="recurringtask",
            name="locked_at",
            field=models.DateTimeField(blank=True, null=True),
        ),
        # How long a lock is honoured before it may be reclaimed.
        migrations.AddField(
            model_name="recurringtask",
            name="timeout",
            field=models.DurationField(default=datetime.timedelta(minutes=30)),
        ),
        # Per-task execution timeout for one-off tasks; NULL means unlimited.
        migrations.AddField(
            model_name="task",
            name="timeout",
            field=models.DurationField(blank=True, null=True),
        ),
        # Install the zero-argument get_recurringtasks_to_process() function
        # (Postgres only) that also reclaims tasks whose lock timed out.
        PostgresOnlyRunSQL.from_sql_file(
            os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0012_get_recurringtasks_to_process.sql",
            ),
            # NOTE(review): reverse drops only the zero-arg overload; the
            # parameterised version from earlier migrations is presumably
            # left in place — confirm that is the intended rollback state.
            reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process()",
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
-- Atomically select and lock the next recurring task to process.
-- Returns at most one row; uses FOR UPDATE SKIP LOCKED so concurrent
-- workers never hand out the same task twice.
CREATE OR REPLACE FUNCTION get_recurringtasks_to_process()
RETURNS SETOF task_processor_recurringtask AS $$
DECLARE
    row_to_return task_processor_recurringtask;
BEGIN
    -- Select the tasks that need to be processed
    FOR row_to_return IN
        SELECT *
        FROM task_processor_recurringtask
        -- A task is eligible when it is unlocked, or when its lock is stale.
        -- The timeout gets one extra minute as a grace period for overhead,
        -- so the lock is only reclaimed after (timeout + 1 minute) — the
        -- original expression `NOW() - timeout + INTERVAL '1 minute'`
        -- applied the grace period in the wrong direction (reclaimed one
        -- minute EARLY); the parentheses below fix the precedence.
        WHERE is_locked = FALSE
           OR (locked_at IS NOT NULL AND locked_at < NOW() - (timeout + INTERVAL '1 minute'))
        ORDER BY id
        LIMIT 1
        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
        FOR UPDATE SKIP LOCKED
    LOOP
        -- Lock every selected task (by updating `is_locked` to true)
        UPDATE task_processor_recurringtask
        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
        -- transaction is complete (but the tasks are still being executed by the current worker)
        SET is_locked = TRUE, locked_at = NOW()
        WHERE id = row_to_return.id;
        -- If we don't explicitly update the columns here, the client will receive a row
        -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
        row_to_return.is_locked := TRUE;
        row_to_return.locked_at := NOW();
        RETURN NEXT row_to_return;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql;

13 changes: 12 additions & 1 deletion task_processor/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing
import uuid
from datetime import datetime
from datetime import datetime, timedelta

import simplejson as json
from django.core.serializers.json import DjangoJSONEncoder
Expand Down Expand Up @@ -80,6 +80,8 @@ def callable(self) -> typing.Callable:
class Task(AbstractBaseTask):
scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)

timeout = models.DurationField(blank=True, null=True)

# denormalise failures and completion so that we can use select_for_update
num_failures = models.IntegerField(default=0)
completed = models.BooleanField(default=False)
Expand Down Expand Up @@ -109,6 +111,7 @@ def create(
*,
args: typing.Tuple[typing.Any] = None,
kwargs: typing.Dict[str, typing.Any] = None,
timeout: timedelta | None = timedelta(seconds=60),
) -> "Task":
if queue_size and cls._is_queue_full(task_identifier, queue_size):
raise TaskQueueFullError(
Expand All @@ -121,6 +124,7 @@ def create(
priority=priority,
serialized_args=cls.serialize_data(args or tuple()),
serialized_kwargs=cls.serialize_data(kwargs or dict()),
timeout=timeout,
)

@classmethod
Expand All @@ -147,6 +151,9 @@ class RecurringTask(AbstractBaseTask):
run_every = models.DurationField()
first_run_time = models.TimeField(blank=True, null=True)

locked_at = models.DateTimeField(blank=True, null=True)
timeout = models.DurationField(default=timedelta(minutes=30))

objects = RecurringTaskManager()

class Meta:
Expand All @@ -157,6 +164,10 @@ class Meta:
),
]

def unlock(self) -> None:
    """Clear the worker lock on this recurring task.

    Resets both lock fields in memory only; the caller is responsible for
    persisting the change (e.g. via ``bulk_update`` on ``is_locked`` and
    ``locked_at``).
    """
    self.is_locked = False
    self.locked_at = None

@property
def should_execute(self) -> bool:
now = timezone.now()
Expand Down
30 changes: 19 additions & 11 deletions task_processor/processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import traceback
import typing
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

from django.utils import timezone
Expand Down Expand Up @@ -36,7 +37,8 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:

if executed_tasks:
Task.objects.bulk_update(
executed_tasks, fields=["completed", "num_failures", "is_locked"]
executed_tasks,
fields=["completed", "num_failures", "is_locked"],
)

if task_runs:
Expand All @@ -48,14 +50,11 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:
return []


def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:
if num_tasks < 1:
raise ValueError("Number of tasks to process must be at least one")

def run_recurring_tasks() -> typing.List[RecurringTaskRun]:
# NOTE: We will probably see a lot of delay in the execution of recurring tasks
# if the tasks take longer then `run_every` to execute. This is not
# a problem for now, but we should be mindful of this limitation
tasks = RecurringTask.objects.get_tasks_to_process(num_tasks)
tasks = RecurringTask.objects.get_tasks_to_process()
if tasks:
task_runs = []

Expand All @@ -78,7 +77,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:

# update all tasks that were not deleted
to_update = [task for task in tasks if task.id]
RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])
RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"])

if task_runs:
RecurringTaskRun.objects.bulk_create(task_runs)
Expand All @@ -93,16 +92,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
task_run = task.task_runs.model(started_at=timezone.now(), task=task)

try:
task.run()
task_run.result = TaskResult.SUCCESS
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(task.run)
timeout = task.timeout.total_seconds() if task.timeout else None
future.result(timeout=timeout) # Wait for completion or timeout

task_run.result = TaskResult.SUCCESS
task_run.finished_at = timezone.now()
task.mark_success()

except Exception as e:
# For errors that don't include a default message (e.g., TimeoutError),
# fall back to using repr.
err_msg = str(e) or repr(e)

logger.error(
"Failed to execute task '%s'. Exception was: %s",
"Failed to execute task '%s', with id %d. Exception: %s",
task.task_identifier,
str(e),
task.id,
err_msg,
exc_info=True,
)
logger.debug("args: %s", str(task.args))
Expand Down
2 changes: 1 addition & 1 deletion task_processor/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run(self) -> None:
def run_iteration(self) -> None:
try:
run_tasks(self.queue_pop_size)
run_recurring_tasks(self.queue_pop_size)
run_recurring_tasks()
except Exception as e:
# To prevent task threads from dying if they get an error retrieving the tasks from the
# database this will allow the thread to continue trying to retrieve tasks if it can
Expand Down
Loading

0 comments on commit 4722bf1

Please sign in to comment.