27 changes: 25 additions & 2 deletions celery_batches/__init__.py
@@ -7,7 +7,6 @@
Collection,
Dict,
Iterable,
NoReturn,
Optional,
Set,
Tuple,
@@ -185,7 +184,7 @@ def __init__(self) -> None:
self._tref: Optional[Timer] = None
self._pool: BasePool = None

def run(self, *args: Any, **kwargs: Any) -> NoReturn:
def run(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError("must implement run(requests)")

def Strategy(self, task: "Batches", app: Celery, consumer: Consumer) -> Callable:
@@ -201,6 +200,9 @@ def Strategy(self, task: "Batches", app: Celery, consumer: Consumer) -> Callable
connection_errors = consumer.connection_errors

eventer = consumer.event_dispatcher
events = eventer and eventer.enabled
send_event = eventer and eventer.send
task_sends_events = events and task.send_events

Request = symbol_by_name(task.Request)
# Celery 5.1 added the app argument to create_request_cls.
@@ -256,6 +258,20 @@ def task_message_handler(

signals.task_received.send(sender=consumer, request=req)

if task_sends_events:
send_event(
"task-received",
uuid=req.id,
name=req.name,
args=req.argsrepr,
kwargs=req.kwargsrepr,
root_id=req.root_id,
parent_id=req.parent_id,
retries=req.request_dict.get("retries", 0),
eta=req.eta and req.eta.isoformat(),
expires=req.expires and req.expires.isoformat(),
)

if self._tref is None: # first request starts flush timer.
self._tref = timer.call_repeatedly(self.flush_interval, flush_buffer)

@@ -358,10 +374,17 @@ def flush(self, requests: Collection[Request]) -> Any:
def on_accepted(pid: int, time_accepted: float) -> None:
for req in acks_early:
req.acknowledge()
for request in requests:
request.send_event("task-started")

def on_return(result: Optional[Any]) -> None:
for req in acks_late:
req.acknowledge()
for request in requests:
runtime = 0
if isinstance(result, int):
runtime = result
request.send_event("task-succeeded", result=None, runtime=runtime)

return self._pool.apply_async(
apply_batches_task,
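Note on the events added above: task-received, task-started, and task-succeeded follow Celery's standard task-event protocol, so an ordinary event consumer should observe them once the worker has events enabled (for example via celery worker -E). A minimal consumer sketch, assuming a hypothetical broker URL:

from celery import Celery

# Hypothetical broker URL; point this at the broker the worker uses.
app = Celery(broker="amqp://guest@localhost//")

def on_event(event: dict) -> None:
    # Batch requests now emit the same fields as regular tasks
    # (uuid, name, args/kwargs reprs, runtime, and so on).
    print(event["type"], event.get("uuid"), event.get("name"))

with app.connection() as connection:
    receiver = app.events.Receiver(
        connection,
        handlers={
            "task-received": on_event,
            "task-started": on_event,
            "task-succeeded": on_event,
        },
    )
    # Blocks and dispatches incoming events to the handlers above.
    receiver.capture(limit=None, timeout=None, wakeup=True)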
54 changes: 48 additions & 6 deletions celery_batches/trace.py
@@ -6,7 +6,7 @@
Mimics some of the functionality found in celery.app.trace.trace_task.
"""

from typing import TYPE_CHECKING, Any, List, Tuple
from typing import TYPE_CHECKING, Any, List, Tuple, Union

from celery import signals, states
from celery._state import _task_stack
@@ -22,8 +22,11 @@
send_prerun = signals.task_prerun.send
send_postrun = signals.task_postrun.send
send_success = signals.task_success.send
send_failure = signals.task_failure.send
send_revoked = signals.task_revoked.send
SUCCESS = states.SUCCESS
FAILURE = states.FAILURE
REVOKED = states.REVOKED


def apply_batches_task(
@@ -38,6 +41,14 @@ def apply_batches_task(
prerun_receivers = signals.task_prerun.receivers
postrun_receivers = signals.task_postrun.receivers
success_receivers = signals.task_success.receivers
failure_receivers = signals.task_failure.receivers
revoked_receivers = signals.task_revoked.receivers

logger.debug(f"Debug: prerun_receivers: {prerun_receivers}")
logger.debug(f"Debug: postrun_receivers: {postrun_receivers}")
logger.debug(f"Debug: success_receivers: {success_receivers}")
logger.debug(f"Debug: failure_receivers: {failure_receivers}")
logger.debug(f"Debug: revoked_receivers: {revoked_receivers}")

# Corresponds to multiple requests, so generate a new UUID.
task_id = uuid()
@@ -46,25 +57,56 @@
task_request = Context(loglevel=loglevel, logfile=logfile)
push_request(task_request)

result: Union[Any, Exception]
state: str

try:
# -*- PRE -*-
if prerun_receivers:
logger.debug("Debug: Sending prerun signal")
send_prerun(sender=task, task_id=task_id, task=task, args=args, kwargs={})

# -*- TRACE -*-
try:
result = task(*args)
state = SUCCESS
state = (
REVOKED
if (hasattr(task.request, "state") and task.request.state == REVOKED)
else SUCCESS
)
except Exception as exc:
result = None
result = exc
state = FAILURE
logger.error("Error: %r", exc, exc_info=True)
else:
if success_receivers:
send_success(sender=task, result=result)
if failure_receivers:
logger.debug("Debug: Sending failure signal")
send_failure(
sender=task,
task_id=task_id,
exception=exc,
args=args,
kwargs={},
einfo=None,
)

# Handle signals based on the state
if state == REVOKED and revoked_receivers:
logger.debug("Debug: Sending revoked signal")
send_revoked(
sender=task,
request=task_request,
terminated=True,
signum=None,
expired=False,
)
elif state == SUCCESS and success_receivers:
logger.debug("Debug: Sending success signal")
send_success(sender=task, result=result)

finally:
try:
if postrun_receivers:
logger.debug("Debug: Sending postrun signal")
send_postrun(
sender=task,
task_id=task_id,
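Note on the signal changes above: with send_failure and send_revoked wired in, the standard task_failure and task_revoked signals now fire for batch tasks as well. A minimal sketch of receivers that would observe them (handler names are illustrative):

from celery import signals

@signals.task_failure.connect
def on_batch_failure(sender=None, task_id=None, exception=None, **kwargs):
    # Fires when the batch callable raises; the caught exception is passed through.
    print(f"batch {task_id} failed: {exception!r}")

@signals.task_revoked.connect
def on_batch_revoked(sender=None, request=None, terminated=False, **kwargs):
    # Fires when the request state was marked REVOKED during the run.
    print(f"batch revoked (terminated={terminated}): {request}")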
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -2,3 +2,4 @@ pytest-celery~=0.0.0
pytest~=6.2
coverage
pytest-timeout
pytest-asyncio
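pytest-asyncio is needed because the new integration test below is written as a coroutine; the plugin simply runs async def tests on a managed event loop. As a minimal illustration (a hypothetical test, not part of this change):

import asyncio

import pytest

@pytest.mark.asyncio
async def test_event_loop_runs() -> None:
    # pytest-asyncio executes this coroutine on an event loop it manages.
    await asyncio.sleep(0)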
5 changes: 5 additions & 0 deletions setup.cfg
@@ -39,6 +39,11 @@ packages =
install_requires = celery>=5.0,<5.5
python_requires = >=3.8

[options.extras_require]
test =
pytest
pytest-asyncio

[flake8]
extend-ignore = E203
max-line-length = 88
85 changes: 85 additions & 0 deletions t/integration/test_events.py
@@ -0,0 +1,85 @@
import asyncio
Owner: Why all the use of asyncio in these tests instead of being more similar to the existing tests?

import logging
from typing import Any, List

from celery_batches import Batches, SimpleRequest

from celery import Celery

import pytest

pytest_plugins = ("pytest_asyncio",)

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def setup_celery() -> Celery:
app = Celery("myapp")
app.conf.update(
broker_url="memory://",
result_backend="cache+memory://",
task_always_eager=False,
worker_concurrency=1,
worker_prefetch_multiplier=1,
task_create_missing_queues=True,
broker_connection_retry_on_startup=True,
)
return app


celery_app = setup_celery()


@celery_app.task(base=Batches, flush_every=2, flush_interval=0.1)
def add(requests: List[SimpleRequest]) -> int:
"""
Add the first argument of each task.

Marks the result of each task as the sum.
"""
logger.debug(f"Processing {len(requests)} requests")
result = int(
sum(
sum(int(arg) for arg in request.args)
+ sum(int(value) for value in request.kwargs.values())
for request in requests
)
)

for request in requests:
celery_app.backend.mark_as_done(request.id, result, request=request)

logger.debug(f"Finished processing. Result: {result}")
return result


@pytest.mark.asyncio
async def test_tasks_for_add(celery_worker: Any) -> None:
logger.debug("Starting test_tasks_for_add")

# Send tasks
logger.debug("Sending tasks")
result_1 = add.delay(1)
result_2 = add.delay(2)

logger.debug("Waiting for results")
try:
# Wait for the batch to be processed
results = await asyncio.wait_for(
asyncio.gather(
asyncio.get_event_loop().run_in_executor(None, result_1.get),
asyncio.get_event_loop().run_in_executor(None, result_2.get),
),
timeout=5.0,
)
logger.debug(f"Results: {results}")
except asyncio.TimeoutError:
logger.error("Test timed out while waiting for results")
pytest.fail("Test timed out while waiting for results")

# Check results
assert results[0] == 3, f"Expected 3, got {results[0]}"
assert results[1] == 3, f"Expected 3, got {results[1]}"

logger.debug("Test completed successfully")