From 00ea322238031f52dde056463b4bfc0dc08b54c0 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sat, 21 Dec 2024 10:34:15 +0200 Subject: [PATCH] Redis Streams for immediate task delivery (#492) --- arq/connections.py | 74 +++++++++++-- arq/constants.py | 3 + arq/jobs.py | 45 +++++++- arq/lua.py | 48 ++++++++ arq/worker.py | 259 ++++++++++++++++++++++++++++++++++++++----- tests/conftest.py | 18 ++- tests/test_lua.py | 122 ++++++++++++++++++++ tests/test_worker.py | 102 +++++++---------- 8 files changed, 571 insertions(+), 100 deletions(-) create mode 100644 arq/lua.py create mode 100644 tests/test_lua.py diff --git a/arq/connections.py b/arq/connections.py index c1058890..e2f2746d 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -13,8 +13,16 @@ from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError -from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix +from .constants import ( + default_queue_name, + expires_extra_ms, + job_key_prefix, + job_message_id_prefix, + result_key_prefix, + stream_key_suffix, +) from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job +from .lua import publish_job_lua from .utils import timestamp_ms, to_ms, to_unix_ms logger = logging.getLogger('arq.connections') @@ -165,20 +173,63 @@ async def enqueue_job( elif defer_by_ms: score = enqueue_time_ms + defer_by_ms else: - score = enqueue_time_ms + score = None - expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms + expires_ms = expires_ms or (score or enqueue_time_ms) - enqueue_time_ms + self.expires_extra_ms - job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) + job = serialize_job( + function, + args, + kwargs, + _job_try, + enqueue_time_ms, + serializer=self.job_serializer, + ) pipe.multi() pipe.psetex(job_key, expires_ms, job) - pipe.zadd(_queue_name, {job_id: score}) + + if score is not None: + pipe.zadd(_queue_name, {job_id: score}) + else: + stream_key = _queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job_id + pipe.eval( + publish_job_lua, + 2, + # keys + stream_key, + job_message_id_key, + # args + job_id, + str(enqueue_time_ms), + str(expires_ms), + ) + try: await pipe.execute() except WatchError: # job got enqueued since we checked 'job_exists' return None - return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer) + return Job( + job_id, + redis=self, + _queue_name=_queue_name, + _deserializer=self.job_deserializer, + ) + + async def get_queue_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int: + if queue_name is None: + queue_name = self.default_queue_name + + async with self.pipeline(transaction=True) as pipe: + pipe.xlen(queue_name + stream_key_suffix) + pipe.zcount(queue_name, '-inf', '+inf') + stream_size, delayed_queue_size = await pipe.execute() + + if not include_delayed_tasks: + return stream_size + + return stream_size + delayed_queue_size async def _get_job_result(self, key: bytes) -> JobResult: job_id = key[len(result_key_prefix) :].decode() @@ -213,7 +264,16 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef] """ if queue_name is None: queue_name = self.default_queue_name - jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1) + + async with self.pipeline(transaction=True) as pipe: + pipe.zrange(queue_name, withscores=True, start=0, 
end=-1) + pipe.xrange(queue_name + stream_key_suffix, '-', '+') + delayed_jobs, stream_jobs = await pipe.execute() + + jobs = [ + *delayed_jobs, + *[(j[b'job_id'], int(j[b'score'])) for _, j in stream_jobs], + ] return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs]) diff --git a/arq/constants.py b/arq/constants.py index 84c009aa..98d6ff99 100644 --- a/arq/constants.py +++ b/arq/constants.py @@ -1,9 +1,12 @@ default_queue_name = 'arq:queue' job_key_prefix = 'arq:job:' in_progress_key_prefix = 'arq:in-progress:' +job_message_id_prefix = 'arq:message-id:' result_key_prefix = 'arq:result:' retry_key_prefix = 'arq:retry:' abort_jobs_ss = 'arq:abort' +stream_key_suffix = ':stream' +default_consumer_group = 'arq:workers' # age of items in the abort_key sorted set after which they're deleted abort_job_max_age = 60 health_check_key_suffix = ':health-check' diff --git a/arq/jobs.py b/arq/jobs.py index 15b7231e..71e8eef3 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -9,7 +9,16 @@ from redis.asyncio import Redis -from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix +from .constants import ( + abort_jobs_ss, + default_queue_name, + in_progress_key_prefix, + job_key_prefix, + job_message_id_prefix, + result_key_prefix, + stream_key_suffix, +) +from .lua import get_job_from_stream_lua from .utils import ms_to_datetime, poll, timestamp_ms logger = logging.getLogger('arq.jobs') @@ -63,6 +72,10 @@ class JobResult(JobDef): queue_name: str +def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]: + return dict(zip(input_list[::2], input_list[1::2], strict=True)) + + class Job: """ Holds data a reference to a job. @@ -105,7 +118,8 @@ async def result( async with self._redis.pipeline(transaction=True) as tr: tr.get(result_key_prefix + self.job_id) tr.zscore(self._queue_name, self.job_id) - v, s = await tr.execute() + tr.get(job_message_id_prefix + self.job_id) + v, s, m = await tr.execute() if v: info = deserialize_result(v, deserializer=self._deserializer) @@ -115,7 +129,7 @@ async def result( raise info.result else: raise SerializationError(info.result) - elif s is None: + elif s is None and m is None: raise ResultNotFound( 'Not waiting for job result because the job is not in queue. ' 'Is the worker function configured to keep result?' 
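
Because immediately published jobs never touch the arq:queue sorted set, zscore alone can no longer tell whether a job is still queued; the arq:message-id: key written by publish_job_lua carries that signal, which is why the result pipeline above checks both. A minimal external sketch of the same check, assuming a plain redis-py async client; job_is_queued is an illustrative helper, not part of this patch:

    from redis.asyncio import Redis

    async def job_is_queued(redis: Redis, queue_name: str, job_id: str) -> bool:
        # Deferred jobs still sit in the sorted set with their due-time score;
        # immediately published jobs are tracked only by the message-id key
        # that mirrors their stream entry.
        score = await redis.zscore(queue_name, job_id)
        message_id = await redis.get('arq:message-id:' + job_id)
        return score is not None or message_id is not None
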
@@ -134,8 +148,24 @@ async def info(self) -> Optional[JobDef]: if v: info = deserialize_job(v, deserializer=self._deserializer) if info: - s = await self._redis.zscore(self._queue_name, self.job_id) - info.score = None if s is None else int(s) + async with self._redis.pipeline(transaction=True) as tr: + tr.zscore(self._queue_name, self.job_id) + tr.eval( + get_job_from_stream_lua, + 2, + self._queue_name + stream_key_suffix, + job_message_id_prefix + self.job_id, + ) + delayed_score, job_info = await tr.execute() + + if delayed_score: + info.score = int(delayed_score) + elif job_info: + _, job_info_payload = job_info + info.score = int(_list_to_dict(job_info_payload)[b'score']) + else: + info.score = None + return info async def result_info(self) -> Optional[JobResult]: @@ -157,12 +187,15 @@ async def status(self) -> JobStatus: tr.exists(result_key_prefix + self.job_id) tr.exists(in_progress_key_prefix + self.job_id) tr.zscore(self._queue_name, self.job_id) - is_complete, is_in_progress, score = await tr.execute() + tr.exists(job_message_id_prefix + self.job_id) + is_complete, is_in_progress, score, queued = await tr.execute() if is_complete: return JobStatus.complete elif is_in_progress: return JobStatus.in_progress + elif queued: + return JobStatus.queued elif score: return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued else: diff --git a/arq/lua.py b/arq/lua.py new file mode 100644 index 00000000..e7bd5230 --- /dev/null +++ b/arq/lua.py @@ -0,0 +1,48 @@ +publish_delayed_job_lua = """ +local delayed_queue_key = KEYS[1] +local stream_key = KEYS[2] +local job_message_id_key = KEYS[3] + +local job_id = ARGV[1] +local job_message_id_expire_ms = ARGV[2] + +local score = redis.call('zscore', delayed_queue_key, job_id) +if score == nil or score == false then + return 0 +end + +local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score) +redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms) +redis.call('zrem', delayed_queue_key, job_id) +return 1 +""" + +publish_job_lua = """ +local stream_key = KEYS[1] +local job_message_id_key = KEYS[2] + +local job_id = ARGV[1] +local score = ARGV[2] +local job_message_id_expire_ms = ARGV[3] + +local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score) +redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms) +return message_id +""" + +get_job_from_stream_lua = """ +local stream_key = KEYS[1] +local job_message_id_key = KEYS[2] + +local message_id = redis.call('get', job_message_id_key) +if message_id == false then + return nil +end + +local job = redis.call('xrange', stream_key, message_id, message_id) +if job == nil then + return nil +end + +return job[1] +""" diff --git a/arq/worker.py b/arq/worker.py index 8fcd5fc8..880dd48d 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -3,13 +3,16 @@ import inspect import logging import signal +from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta, timezone from functools import partial from signal import Signals from time import time from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from uuid import uuid4 +from redis.asyncio.client import Pipeline from redis.exceptions import ResponseError, WatchError from arq.cron import CronJob @@ -19,15 +22,19 @@ from .constants import ( abort_job_max_age, abort_jobs_ss, + default_consumer_group, default_queue_name, expires_extra_ms, 
health_check_key_suffix, in_progress_key_prefix, job_key_prefix, + job_message_id_prefix, keep_cronjob_progress, result_key_prefix, retry_key_prefix, + stream_key_suffix, ) +from .lua import publish_delayed_job_lua, publish_job_lua from .utils import ( args_to_string, import_string, @@ -57,6 +64,13 @@ class Function: max_tries: Optional[int] +@dataclass +class JobMetaInfo: + message_id: str + job_id: str + score: int + + def func( coroutine: Union[str, Function, 'WorkerCoroutine'], *, @@ -188,6 +202,8 @@ def __init__( functions: Sequence[Union[Function, 'WorkerCoroutine']] = (), *, queue_name: Optional[str] = default_queue_name, + consumer_group_name: str = default_consumer_group, + worker_id: Optional[str] = None, cron_jobs: Optional[Sequence[CronJob]] = None, redis_settings: Optional[RedisSettings] = None, redis_pool: Optional[ArqRedis] = None, @@ -204,6 +220,7 @@ def __init__( keep_result: 'SecondsTimedelta' = 3600, keep_result_forever: bool = False, poll_delay: 'SecondsTimedelta' = 0.5, + stream_block: 'SecondsTimedelta' = 0.5, queue_read_limit: Optional[int] = None, max_tries: int = 5, health_check_interval: 'SecondsTimedelta' = 3600, @@ -217,6 +234,8 @@ def __init__( expires_extra_ms: int = expires_extra_ms, timezone: Optional[timezone] = None, log_results: bool = True, + max_consumer_inactivity: 'SecondsTimedelta' = 86400, + idle_consumer_poll_interval: 'SecondsTimedelta' = 60, ): self.functions: Dict[str, Union[Function, CronJob]] = {f.name: f for f in map(func, functions)} if queue_name is None: @@ -225,6 +244,8 @@ def __init__( else: raise ValueError('If queue_name is absent, redis_pool must be present.') self.queue_name = queue_name + self.consumer_group_name = consumer_group_name + self.worker_id = worker_id or str(uuid4().hex) self.cron_jobs: List[CronJob] = [] if cron_jobs is not None: if not all(isinstance(cj, CronJob) for cj in cron_jobs): @@ -248,6 +269,9 @@ def __init__( self.keep_result_s = to_seconds(keep_result) self.keep_result_forever = keep_result_forever self.poll_delay_s = to_seconds(poll_delay) + self.stream_block_s = to_seconds(stream_block) + self.max_consumer_inactivity_s = to_seconds(max_consumer_inactivity) + self.idle_consumer_poll_interval_s = to_seconds(idle_consumer_poll_interval) self.queue_read_limit = queue_read_limit or max(max_jobs * 5, 100) self._queue_read_offset = 0 self.max_tries = max_tries @@ -357,19 +381,100 @@ async def main(self) -> None: if self.on_startup: await self.on_startup(self.ctx) - async for _ in poll(self.poll_delay_s): - await self._poll_iteration() + await self.create_consumer_group() + + done, pending = await asyncio.wait( + [ + asyncio.ensure_future(self.run_delayed_queue_poller()), + asyncio.ensure_future(self.run_stream_reader()), + asyncio.ensure_future(self.run_idle_consumer_cleanup()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() + + await asyncio.gather(*pending, return_exceptions=True) + + for task in done: + task.result() + + async def run_stream_reader(self) -> None: + while True: + await self._read_stream_iteration() if self.burst: if 0 <= self.max_burst_jobs <= self._jobs_started(): await asyncio.gather(*self.tasks.values()) return None - queued_jobs = await self.pool.zcard(self.queue_name) + queued_jobs = await self.pool.get_queue_size(self.queue_name) if queued_jobs == 0: await asyncio.gather(*self.tasks.values()) return None - async def _poll_iteration(self) -> None: + async def run_delayed_queue_poller(self) -> None: + publish_delayed_job = 
self.pool.register_script(publish_delayed_job_lua) + + async for _ in poll(self.poll_delay_s): + job_ids = await self.pool.zrange( + self.queue_name, + start=float('-inf'), + end=timestamp_ms(), + num=self.queue_read_limit, + offset=self._queue_read_offset, + withscores=True, + byscore=True, + ) + async with self.pool.pipeline(transaction=False) as pipe: + for job_id, score in job_ids: + expire_ms = int(score - timestamp_ms() + self.expires_extra_ms) + if expire_ms <= 0: + expire_ms = self.expires_extra_ms + + await publish_delayed_job( + keys=[ + self.queue_name, + self.queue_name + stream_key_suffix, + job_message_id_prefix + job_id.decode(), + ], + args=[job_id.decode(), expire_ms], + client=pipe, + ) + + await pipe.execute() + + async def run_idle_consumer_cleanup(self) -> None: + async for _ in poll(self.idle_consumer_poll_interval_s): + consumers_info = await self.pool.xinfo_consumers( + self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + ) + + for consumer_info in consumers_info: + if self.worker_id == consumer_info['name'].decode(): + continue + + idle = timedelta(milliseconds=consumer_info['idle']).seconds + pending = consumer_info['pending'] + + if pending == 0 and idle > self.max_consumer_inactivity_s: + await self.pool.xgroup_delconsumer( + name=self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + consumername=consumer_info['name'], + ) + + async def create_consumer_group(self) -> None: + with suppress(ResponseError): + await self.pool.xgroup_create( + name=self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + id='0', + mkstream=True, + ) + + async def _read_stream_iteration(self) -> None: """ Get ids of pending jobs from the main queue sorted-set data structure and start those jobs, remove any finished tasks from self.tasks. 
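
The stream reader added below relies on standard Redis consumer-group semantics: each entry is delivered to exactly one consumer in the group, entries left pending longer than the in-progress timeout can be reclaimed with XAUTOCLAIM, and finished entries are removed with XACK plus XDEL. A rough standalone sketch of that flow, assuming a plain redis-py async client with illustrative stream, group and consumer names (not code taken from this patch):

    from redis.asyncio import Redis
    from redis.exceptions import ResponseError

    STREAM = 'arq:queue:stream'   # queue name + stream_key_suffix
    GROUP = 'arq:workers'         # default_consumer_group

    async def read_once(redis: Redis, consumer: str) -> None:
        # Create the group once; a BUSYGROUP ResponseError means it already exists.
        try:
            await redis.xgroup_create(name=STREAM, groupname=GROUP, id='0', mkstream=True)
        except ResponseError:
            pass
        # '>' requests entries never delivered to any consumer in this group.
        resp = await redis.xreadgroup(
            groupname=GROUP,
            consumername=consumer,
            streams={STREAM: '>'},
            count=10,
            block=500,
        )
        for _stream_key, messages in resp:
            for message_id, fields in messages:
                job_id = fields[b'job_id'].decode()
                # ... run the job, then acknowledge and drop the stream entry
                await redis.xack(STREAM, GROUP, message_id)
                await redis.xdel(STREAM, message_id)
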
@@ -382,12 +487,35 @@ async def _poll_iteration(self) -> None: count = min(burst_jobs_remaining, count) if self.allow_pick_jobs: if self.job_counter < self.max_jobs: - now = timestamp_ms() - job_ids = await self.pool.zrangebyscore( - self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now - ) + stream_msgs = await self._get_idle_tasks(count) + msgs_count = sum([len(msgs) for _, msgs in stream_msgs]) + + count -= msgs_count + + if count > 0: + stream_msgs.extend( + await self.pool.xreadgroup( + groupname=self.consumer_group_name, + consumername=self.worker_id, + streams={self.queue_name + stream_key_suffix: '>'}, + count=count, + block=int(max(self.stream_block_s * 1000, 1)), + ) + ) + + jobs = [] - await self.start_jobs(job_ids) + for _, msgs in stream_msgs: + for msg_id, job in msgs: + jobs.append( + JobMetaInfo( + message_id=msg_id.decode(), + job_id=job[b'job_id'].decode(), + score=int(job[b'score']), + ) + ) + + await self.start_jobs(jobs) if self.allow_abort_jobs: await self._cancel_aborted_jobs() @@ -400,6 +528,25 @@ async def _poll_iteration(self) -> None: await self.heart_beat() + async def _get_idle_tasks(self, count: int) -> list[tuple[bytes, list]]: + resp = await self.pool.xautoclaim( + self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + consumername=self.worker_id, + min_idle_time=int(self.in_progress_timeout_s * 1000), + count=count, + ) + + if not resp: + return [] + + _, msgs, __ = resp + if not msgs: + return [] + + # cast to the same format as the xreadgroup response + return [((self.queue_name + stream_key_suffix).encode(), msgs)] + async def _cancel_aborted_jobs(self) -> None: """ Go through job_ids in the abort_jobs_ss sorted set and cancel those tasks. @@ -428,11 +575,14 @@ def _release_sem_dec_counter_on_complete(self) -> None: self.job_counter = self.job_counter - 1 self.sem.release() - async def start_jobs(self, job_ids: List[bytes]) -> None: + async def start_jobs(self, jobs: list[JobMetaInfo]) -> None: """ For each job id, get the job definition, check it's not running and start it in a task """ - for job_id_b in job_ids: + for job in jobs: + job_id = job.job_id + score = job.score + await self.sem.acquire() if self.job_counter >= self.max_jobs: @@ -441,16 +591,15 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: self.job_counter = self.job_counter + 1 - job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id async with self.pool.pipeline(transaction=True) as pipe: await pipe.watch(in_progress_key) ongoing_exists = await pipe.exists(in_progress_key) - score = await pipe.zscore(self.queue_name, job_id) - if ongoing_exists or not score or score > timestamp_ms(): - # job already started elsewhere, or already finished and removed from queue - # if score > ts_now, - # it means probably the job was re-enqueued with a delay in another worker + + if ongoing_exists: + await pipe.unwatch() + await self._unclaim_job(job, pipe) + await pipe.execute() self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('job %s already running elsewhere', job_id) @@ -462,15 +611,40 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: await pipe.execute() except (ResponseError, WatchError): # job already started elsewhere since we got 'existing' + pipe.multi() + await self._unclaim_job(job, pipe) + await pipe.execute() self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('multi-exec error, job %s already started elsewhere', job_id) else: - t = 
self.loop.create_task(self.run_job(job_id, int(score))) + t = self.loop.create_task(self.run_job(job_id, job.message_id, score)) t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete()) self.tasks[job_id] = t - async def run_job(self, job_id: str, score: int) -> None: # noqa: C901 + async def _unclaim_job(self, job: JobMetaInfo, pipe: Pipeline) -> None: + stream_key = self.queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job.job_id + + pipe.xack(stream_key, self.consumer_group_name, job.message_id) + pipe.xdel(stream_key, job.message_id) + job_message_id_expire = job.score - timestamp_ms() + self.expires_extra_ms + if job_message_id_expire <= 0: + job_message_id_expire = self.expires_extra_ms + + pipe.eval( + publish_job_lua, + 2, + # keys + stream_key, + job_message_id_key, + # args + job.job_id, + str(job.score), + str(job_message_id_expire), + ) + + async def run_job(self, job_id: str, message_id: str, score: int) -> None: # noqa: C901 start_ms = timestamp_ms() async with self.pool.pipeline(transaction=True) as pipe: pipe.get(job_key_prefix + job_id) @@ -504,7 +678,7 @@ async def job_failed(exc: BaseException) -> None: queue_name=self.queue_name, job_id=job_id, ) - await asyncio.shield(self.finish_failed_job(job_id, result_data_)) + await asyncio.shield(self.finish_failed_job(job_id, message_id, result_data_)) if not v: logger.warning('job %s expired', job_id) @@ -561,7 +735,7 @@ async def job_failed(exc: BaseException) -> None: job_id=job_id, serializer=self.job_serializer, ) - return await asyncio.shield(self.finish_failed_job(job_id, result_data)) + return await asyncio.shield(self.finish_failed_job(job_id, message_id, result_data)) result = no_result exc_extra = None @@ -662,6 +836,8 @@ async def job_failed(exc: BaseException) -> None: await asyncio.shield( self.finish_job( job_id, + message_id, + score, finish, result_data, result_timeout_s, @@ -677,6 +853,8 @@ async def job_failed(exc: BaseException) -> None: async def finish_job( self, job_id: str, + message_id: str, + score: int, finish: bool, result_data: Optional[bytes], result_timeout_s: Optional[float], @@ -687,33 +865,62 @@ async def finish_job( async with self.pool.pipeline(transaction=True) as tr: delete_keys = [] in_progress_key = in_progress_key_prefix + job_id + stream_key = self.queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job_id if keep_in_progress is None: delete_keys += [in_progress_key] else: tr.pexpire(in_progress_key, to_ms(keep_in_progress)) + tr.xack( + stream_key, + self.consumer_group_name, + message_id, + ) + tr.xdel(stream_key, message_id) + if finish: if result_data: expire = None if keep_result_forever else result_timeout_s tr.set(result_key_prefix + job_id, result_data, px=to_ms(expire)) - delete_keys += [retry_key_prefix + job_id, job_key_prefix + job_id] + delete_keys += [retry_key_prefix + job_id, job_key_prefix + job_id, job_message_id_key] tr.zrem(abort_jobs_ss, job_id) - tr.zrem(self.queue_name, job_id) elif incr_score: - tr.zincrby(self.queue_name, incr_score, job_id) + delete_keys += [job_message_id_key] + tr.zadd(self.queue_name, {job_id: score + incr_score}) + else: + job_message_id_expire = score - timestamp_ms() + self.expires_extra_ms + tr.eval( + publish_job_lua, + 2, + # keys + stream_key, + job_message_id_key, + # args + job_id, + str(score), + str(job_message_id_expire), + ) if delete_keys: tr.delete(*delete_keys) await tr.execute() - async def finish_failed_job(self, job_id: str, result_data: 
Optional[bytes]) -> None: + async def finish_failed_job(self, job_id: str, message_id: str, result_data: Optional[bytes]) -> None: + stream_key = self.queue_name + stream_key_suffix async with self.pool.pipeline(transaction=True) as tr: tr.delete( retry_key_prefix + job_id, in_progress_key_prefix + job_id, job_key_prefix + job_id, + job_message_id_prefix + job_id, ) tr.zrem(abort_jobs_ss, job_id) - tr.zrem(self.queue_name, job_id) + tr.xack( + stream_key, + self.consumer_group_name, + message_id, + ) + tr.xdel(stream_key, message_id) # result_data would only be None if serializing the result fails keep_result = self.keep_result_forever or self.keep_result_s > 0 if result_data is not None and keep_result: # pragma: no branch diff --git a/tests/conftest.py b/tests/conftest.py index 9b6b7f5b..b2123ed5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,10 +94,24 @@ async def arq_redis_retry(test_redis_host: str, test_redis_port: int): async def worker(arq_redis): worker_: Worker = None - def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_redis, **kwargs): + def create( + functions=[], + burst=True, + poll_delay=0, + stream_block=0, + max_jobs=10, + arq_redis=arq_redis, + **kwargs, + ): nonlocal worker_ worker_ = Worker( - functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs + functions=functions, + redis_pool=arq_redis, + burst=burst, + poll_delay=poll_delay, + max_jobs=max_jobs, + stream_block=stream_block, + **kwargs, ) return worker_ diff --git a/tests/test_lua.py b/tests/test_lua.py new file mode 100644 index 00000000..53e90e14 --- /dev/null +++ b/tests/test_lua.py @@ -0,0 +1,122 @@ +import pytest +from redis.commands.core import AsyncScript + +from arq import ArqRedis +from arq.lua import get_job_from_stream_lua, publish_delayed_job_lua, publish_job_lua + + +@pytest.fixture() +def publish_delayed_job(arq_redis: ArqRedis) -> AsyncScript: + return arq_redis.register_script(publish_delayed_job_lua) + + +@pytest.fixture() +def publish_job(arq_redis: ArqRedis) -> AsyncScript: + return arq_redis.register_script(publish_job_lua) + + +@pytest.fixture() +def get_job_from_stream(arq_redis: ArqRedis) -> AsyncScript: + return arq_redis.register_script(get_job_from_stream_lua) + + +async def test_publish_delayed_job(arq_redis: ArqRedis, publish_delayed_job: AsyncScript) -> None: + await arq_redis.zadd('delayed_queue_key', {'job_id': 1000}) + await publish_delayed_job( + keys=[ + 'delayed_queue_key', + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + ], + ) + + stream_msgs = await arq_redis.xrange('stream_key', '-', '+') + assert len(stream_msgs) == 1 + + saved_msg_id = await arq_redis.get('job_message_id_key') + + msg_id, msg = stream_msgs[0] + assert msg == {b'job_id': b'job_id', b'score': b'1000'} + assert saved_msg_id == msg_id + + assert await arq_redis.zrange('delayed_queue_key', '-inf', '+inf', byscore=True) == [] + + await publish_delayed_job( + keys=[ + 'delayed_queue_key', + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + ], + ) + + stream_msgs = await arq_redis.xrange('stream_key', '-', '+') + assert len(stream_msgs) == 1 + + saved_msg_id = await arq_redis.get('job_message_id_key') + assert saved_msg_id == msg_id + + +async def test_publish_job(arq_redis: ArqRedis, publish_job: AsyncScript) -> None: + msg_id = await publish_job( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + '1000', + ], + ) + + stream_msgs 
= await arq_redis.xrange('stream_key', '-', '+') + assert len(stream_msgs) == 1 + + saved_msg_id = await arq_redis.get('job_message_id_key') + assert saved_msg_id == msg_id + + msg_id, msg = stream_msgs[0] + assert msg == {b'job_id': b'job_id', b'score': b'1000'} + assert saved_msg_id == msg_id + + +async def test_get_job_from_stream( + arq_redis: ArqRedis, publish_job: AsyncScript, get_job_from_stream: AsyncScript +) -> None: + msg_id = await publish_job( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + '1000', + ], + ) + + job = await get_job_from_stream( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + ) + + assert job == [msg_id, [b'job_id', b'job_id', b'score', b'1000']] + + await arq_redis.delete('job_message_id_key') + job = await get_job_from_stream( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + ) + assert job is None diff --git a/tests/test_worker.py b/tests/test_worker.py index 93fbc7f0..cbc3102e 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -17,6 +17,7 @@ from arq.worker import ( FailedJobs, JobExecutionFailed, + JobMetaInfo, Retry, RetryJob, Worker, @@ -71,7 +72,7 @@ async def test_set_health_check_key(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') worker: Worker = worker(functions=[func(foobar, keep_result=0)], health_check_key='arq:test:health-check') await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:test:health-check'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:stream', b'arq:test:health-check'] async def test_handle_sig(caplog, arq_redis: ArqRedis): @@ -160,40 +161,6 @@ async def test_job_successful(arq_redis: ArqRedis, worker, caplog): assert 'X.XXs → testing:foobar()\n X.XXs ← testing:foobar ● 42' in log -async def test_job_retry_race_condition(arq_redis: ArqRedis, worker): - async def retry_job(ctx): - if ctx['job_try'] == 1: - raise Retry(defer=10) - - job_id = 'testing' - await arq_redis.enqueue_job('retry_job', _job_id=job_id) - - worker_one: Worker = worker(functions=[func(retry_job, name='retry_job')]) - worker_two: Worker = worker(functions=[func(retry_job, name='retry_job')]) - - assert worker_one.jobs_complete == 0 - assert worker_one.jobs_failed == 0 - assert worker_one.jobs_retried == 0 - - assert worker_two.jobs_complete == 0 - assert worker_two.jobs_failed == 0 - assert worker_two.jobs_retried == 0 - - await worker_one.start_jobs([job_id.encode()]) - await asyncio.gather(*worker_one.tasks.values()) - - await worker_two.start_jobs([job_id.encode()]) - await asyncio.gather(*worker_two.tasks.values()) - - assert worker_one.jobs_complete == 0 - assert worker_one.jobs_failed == 0 - assert worker_one.jobs_retried == 1 - - assert worker_two.jobs_complete == 0 - assert worker_two.jobs_failed == 0 - assert worker_two.jobs_retried == 0 - - async def test_job_successful_no_result_logging(arq_redis: ArqRedis, worker, caplog): caplog.set_level(logging.INFO) await arq_redis.enqueue_job('foobar', _job_id='testing') @@ -214,6 +181,8 @@ async def retry(ctx): await arq_redis.enqueue_job('retry', _job_id='testing') worker: Worker = worker(functions=[func(retry, name='retry')]) await worker.main() + + assert await worker.pool.get_queue_size(worker.queue_name) == 0 assert worker.jobs_complete == 1 assert worker.jobs_failed == 0 assert worker.jobs_retried == 2 @@ -517,38 +486,38 @@ async def test_log_health_check(arq_redis: ArqRedis, worker, caplog): async def test_remain_keys(test_redis_settings: RedisSettings, arq_redis: 
ArqRedis, worker, create_pool): redis2 = await create_pool(test_redis_settings) await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await redis2.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await redis2.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[foobar]) await worker.main() - assert sorted(await redis2.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + assert sorted(await redis2.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream', b'arq:result:testing'] await worker.close() - assert sorted(await redis2.keys('*')) == [b'arq:result:testing'] + assert sorted(await redis2.keys('*')) == [b'arq:queue:stream', b'arq:result:testing'] async def test_remain_keys_no_results(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[func(foobar, keep_result=0)]) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream'] async def test_remain_keys_keep_results_forever_in_function(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[func(foobar, keep_result_forever=True)]) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream', b'arq:result:testing'] ttl_result = await arq_redis.ttl('arq:result:testing') assert ttl_result == -1 async def test_remain_keys_keep_results_forever(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[func(foobar)], keep_result_forever=True) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream', b'arq:result:testing'] ttl_result = await arq_redis.ttl('arq:result:testing') assert ttl_result == -1 @@ -644,23 +613,24 @@ async def test_queue_read_limit_equals_max_jobs(arq_redis: ArqRedis, worker): for _ in range(4): await arq_redis.enqueue_job('foobar') - assert await arq_redis.zcard(default_queue_name) == 4 + assert await arq_redis.get_queue_size(default_queue_name) == 4 worker: Worker = worker(functions=[foobar], queue_read_limit=2) assert worker.queue_read_limit == 2 assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 - await worker._poll_iteration() + await worker.create_consumer_group() + await worker._read_stream_iteration() await asyncio.sleep(0.1) - assert await arq_redis.zcard(default_queue_name) == 2 + assert await arq_redis.get_queue_size(default_queue_name) == 2 assert 
worker.jobs_complete == 2 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 - await worker._poll_iteration() + await worker._read_stream_iteration() await asyncio.sleep(0.1) - assert await arq_redis.zcard(default_queue_name) == 0 + assert await arq_redis.get_queue_size(default_queue_name) == 0 assert worker.jobs_complete == 4 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 @@ -677,20 +647,21 @@ async def test_custom_queue_read_limit(arq_redis: ArqRedis, worker): for _ in range(4): await arq_redis.enqueue_job('foobar') - assert await arq_redis.zcard(default_queue_name) == 4 + assert await arq_redis.get_queue_size(default_queue_name) == 4 worker: Worker = worker(functions=[foobar], max_jobs=4, queue_read_limit=2) assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 - await worker._poll_iteration() + await worker.create_consumer_group() + await worker._read_stream_iteration() await asyncio.sleep(0.1) - assert await arq_redis.zcard(default_queue_name) == 2 + assert await arq_redis.get_queue_size(default_queue_name) == 2 assert worker.jobs_complete == 2 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 - await worker._poll_iteration() + await worker._read_stream_iteration() await asyncio.sleep(0.1) assert await arq_redis.zcard(default_queue_name) == 0 assert worker.jobs_complete == 4 @@ -814,7 +785,7 @@ async def foo(ctx, v): worker.max_burst_jobs = 0 assert len(worker.tasks) == 0 - await worker._poll_iteration() + await worker._read_stream_iteration() assert len(worker.tasks) == 0 @@ -846,7 +817,20 @@ async def foo(ctx, v): caplog.set_level(logging.DEBUG, logger='arq.worker') await arq_redis.enqueue_job('foo', 1, _job_id='testing') worker: Worker = worker(functions=[func(foo, name='foo')]) - await asyncio.gather(*[worker.start_jobs([b'testing']) for _ in range(5)]) + await asyncio.gather( + *[ + worker.start_jobs( + [ + JobMetaInfo( + job_id='testing', + message_id='1', + score=1, + ) + ] + ) + for _ in range(5) + ] + ) # debug(caplog.text) await worker.main() assert c == 1 @@ -1078,7 +1062,7 @@ async def test_worker_retry(mocker, worker_retry, exception_thrown): # baseline await worker.main() - await worker._poll_iteration() + await worker._read_stream_iteration() # spy method handling call_with_retry failure spy = mocker.spy(worker.pool, '_disconnect_raise') @@ -1089,7 +1073,7 @@ async def test_worker_retry(mocker, worker_retry, exception_thrown): # assert exception thrown with pytest.raises(type(exception_thrown)): - await worker._poll_iteration() + await worker._read_stream_iteration() # assert retry counts and no exception thrown during '_disconnect_raise' assert spy.call_count == 4 # retries setting + 1 @@ -1116,7 +1100,7 @@ async def test_worker_crash(mocker, worker, exception_thrown): # baseline await worker.main() - await worker._poll_iteration() + await worker._read_stream_iteration() # spy method handling call_with_retry failure spy = mocker.spy(worker.pool, '_disconnect_raise') @@ -1127,7 +1111,7 @@ async def test_worker_crash(mocker, worker, exception_thrown): # assert exception thrown with pytest.raises(type(exception_thrown)): - await worker._poll_iteration() + await worker._read_stream_iteration() # assert no retry counts and exception thrown during '_disconnect_raise' assert spy.call_count == 1
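
From the caller's side the change is transparent: enqueue_job without _defer_until/_defer_by publishes straight to arq:queue:stream, while deferred jobs stay in the arq:queue sorted set until publish_delayed_job_lua promotes them when their score is due. A hedged usage sketch against arq's public API (the say_hello function and default Redis settings are illustrative):

    import asyncio
    from arq import create_pool
    from arq.connections import RedisSettings

    async def say_hello(ctx, name: str) -> str:
        return f'hello {name}'

    async def main() -> None:
        redis = await create_pool(RedisSettings())
        # Immediate: published to arq:queue:stream and picked up on the next
        # XREADGROUP call instead of waiting for a poll tick.
        await redis.enqueue_job('say_hello', 'world')
        # Deferred: kept in the arq:queue sorted set until the score is due,
        # then moved to the stream atomically by publish_delayed_job_lua.
        await redis.enqueue_job('say_hello', 'later', _defer_by=30)

    asyncio.run(main())

A worker configured with functions=[say_hello] consumes both jobs exactly as before; only the delivery path underneath has changed.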