Skip to content

Commit

Permalink
fix locks for non-recurring tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
gagantrivedi committed Jan 8, 2025
1 parent 6e07161 commit ed052fa
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 8 deletions.
25 changes: 23 additions & 2 deletions task_processor/migrations/0012_add_locked_at_and_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ class Migration(migrations.Migration):
name="locked_at",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="task",
name="locked_at",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="recurringtask",
name="timeout",
Expand All @@ -26,14 +31,30 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="task",
name="timeout",
field=models.DurationField(blank=True, null=True),
field=models.DurationField(default=datetime.timedelta(minutes=1)),
),
PostgresOnlyRunSQL.from_sql_file(
os.path.join(
os.path.dirname(__file__),
"sql",
"0012_get_recurringtasks_to_process.sql",
),
reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process",
reverse_sql=os.path.join(
os.path.dirname(__file__),
"sql",
"0008_get_recurringtasks_to_process.sql",
),
),
PostgresOnlyRunSQL.from_sql_file(
os.path.join(
os.path.dirname(__file__),
"sql",
"0012_get_tasks_to_process.sql",
),
reverse_sql=os.path.join(
os.path.dirname(__file__),
"sql",
"0011_get_tasks_to_process.sql",
),
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ BEGIN
FOR row_to_return IN
SELECT *
FROM task_processor_recurringtask
WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
ORDER BY id
LIMIT num_tasks
-- Select for update to ensure that no other workers can select these tasks while in this transaction block
Expand Down
31 changes: 31 additions & 0 deletions task_processor/migrations/sql/0012_get_tasks_to_process.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer)
RETURNS SETOF task_processor_task AS $$
DECLARE
    row_to_return task_processor_task;
BEGIN
    -- Select the tasks that need to be processed: not yet completed, not
    -- permanently failed (< 3 failures), due to run, and either unlocked or
    -- holding a stale lock (locked_at older than the task's own timeout —
    -- i.e. a worker died mid-run and the lock should be reclaimed).
    FOR row_to_return IN
        SELECT *
        FROM task_processor_task
        WHERE num_failures < 3
          AND scheduled_for < NOW()
          AND completed = FALSE
          AND (is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout))
        ORDER BY priority ASC, scheduled_for ASC, created_at ASC
        LIMIT num_tasks
        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
        FOR UPDATE SKIP LOCKED
    LOOP
        UPDATE task_processor_task
        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
        -- transaction is complete (but the tasks are still being executed by the current worker)
        SET is_locked = TRUE, locked_at = NOW()
        WHERE id = row_to_return.id;
        -- If we don't explicitly update the columns here, the client will receive a row
        -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
        -- NOW() is stable within a transaction, so this matches the UPDATE above exactly.
        row_to_return.is_locked := TRUE;
        row_to_return.locked_at := NOW();
        RETURN NEXT row_to_return;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql

7 changes: 4 additions & 3 deletions task_processor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class AbstractBaseTask(models.Model):
serialized_kwargs = models.TextField(blank=True, null=True)
is_locked = models.BooleanField(default=False)

locked_at = models.DateTimeField(blank=True, null=True)

class Meta:
abstract = True

Expand Down Expand Up @@ -81,7 +83,7 @@ def callable(self) -> typing.Callable:
class Task(AbstractBaseTask):
scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)

timeout = models.DurationField(null=True, blank=True)
timeout = models.DurationField(default=timedelta(minutes=1))

# denormalise failures and completion so that we can use select_for_update
num_failures = models.IntegerField(default=0)
Expand Down Expand Up @@ -112,7 +114,7 @@ def create(
*,
args: typing.Tuple[typing.Any] = None,
kwargs: typing.Dict[str, typing.Any] = None,
timeout: timedelta | None = None,
timeout: timedelta | None = timedelta(seconds=60),
) -> "Task":
if queue_size and cls._is_queue_full(task_identifier, queue_size):
raise TaskQueueFullError(
Expand Down Expand Up @@ -151,7 +153,6 @@ def mark_success(self):
class RecurringTask(AbstractBaseTask):
run_every = models.DurationField()
first_run_time = models.TimeField(blank=True, null=True)
locked_at = models.DateTimeField(blank=True, null=True)

timeout = models.DurationField(default=timedelta(minutes=30))

Expand Down
3 changes: 2 additions & 1 deletion task_processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:

if executed_tasks:
Task.objects.bulk_update(
executed_tasks, fields=["completed", "num_failures", "is_locked"]
executed_tasks,
fields=["completed", "num_failures", "is_locked", "locked_at"],
)

if task_runs:
Expand Down
49 changes: 48 additions & 1 deletion tests/unit/task_processor/test_unit_task_processor_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,51 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(db):
assert task.completed


def test_run_tasks_runs_locked_task_after_tiemout(  # NOTE: "tiemout" typo kept to avoid renaming a collected test
    db: None,
) -> None:
    """A (non-recurring) task whose lock has gone stale — ``locked_at`` older
    than the task's ``timeout`` — must be picked up again by ``run_tasks``,
    executed successfully, and left unlocked afterwards.
    """
    # Given
    assert cache.get(DEFAULT_CACHE_KEY) is None

    task = Task.create(
        _dummy_task.task_identifier,
        timeout=timedelta(seconds=10),
        scheduled_for=timezone.now(),
    )
    # Simulate a worker that locked the task and then died: the lock is a
    # minute old, well past the 10-second timeout, so the task is eligible.
    task.is_locked = True
    task.locked_at = timezone.now() - timedelta(minutes=1)
    task.save()

    # When
    task_runs = run_tasks()

    # Then — the task body ran (it writes DEFAULT_CACHE_VALUE to the cache)
    assert cache.get(DEFAULT_CACHE_KEY) == DEFAULT_CACHE_VALUE

    assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1
    task_run = task_runs[0]
    assert task_run.result == TaskResult.SUCCESS
    assert task_run.started_at
    assert task_run.finished_at
    assert task_run.error_details is None

    # And the task is no longer locked
    task.refresh_from_db()
    assert task.is_locked is False
    assert task.locked_at is None


def test_run_task_kills_task_after_timeout(
db: None,
get_task_processor_caplog: "GetTaskProcessorCaplog",
Expand Down Expand Up @@ -182,6 +227,8 @@ def test_run_recurring_tasks_runs_locked_task_after_tiemout(
# Given
monkeypatch.setenv("RUN_BY_PROCESSOR", "True")

assert cache.get(DEFAULT_CACHE_KEY) is None

@register_recurring_task(run_every=timedelta(hours=1))
def _dummy_recurring_task():
cache.set(DEFAULT_CACHE_KEY, DEFAULT_CACHE_VALUE)
Expand All @@ -197,7 +244,7 @@ def _dummy_recurring_task():
task_runs = run_recurring_tasks()

# Then
assert cache.get(DEFAULT_CACHE_KEY)
assert cache.get(DEFAULT_CACHE_KEY) == DEFAULT_CACHE_VALUE

assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1
task_run = task_runs[0]
Expand Down

0 comments on commit ed052fa

Please sign in to comment.