Skip to content

Commit

Permalink
Merge pull request #47 from epandurski/master
Browse files Browse the repository at this point in the history
Look for a DEACTIVATED marker file
  • Loading branch information
epandurski authored Oct 7, 2024
2 parents b285c3f + cad708a commit c9cbb2c
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 0 deletions.
1 change: 1 addition & 0 deletions development.env
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ APP_ASSOCIATED_LOGGERS=swpt_pythonlib.flask_signalbus.signalbus_cli
APP_SSL_HANDSHAKE_TIMEOUT=5
APP_MAX_CACHED_PEERS=5000
APP_PEERS_CACHE_SECONDS=600
APP_PEERS_CHECK_SECONDS=3600
APP_FILE_READ_THREADS=5
APP_RMQ_CONNECTION_TIMEOUT_SECONDS=10
APP_RMQ_CONFIRMATION_TIMEOUT_SECONDS=20
Expand Down
7 changes: 7 additions & 0 deletions swpt_stomp/peer_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,13 @@ async def _get_peer_data(
except FileNotFoundError: # pragma: nocover
return None

try:
await self._read_file(f"{dir}/DEACTIVATED")
except FileNotFoundError:
pass
else: # pragma: nocover
return None

try:
# Peers that do not have a file with the name "ACTIVE" in their
# corresponding directories are considered inactive. The
Expand Down
10 changes: 10 additions & 0 deletions swpt_stomp/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def _NO_TM(m: "RmqMessage") -> Message:
)


async def _NEVER() -> bool:
return False


APP_RMQ_CONNECTION_TIMEOUT_SECONDS = float(
os.environ.get(
"APP_RMQ_CONNECTION_TIMEOUT_SECONDS",
Expand Down Expand Up @@ -124,6 +128,7 @@ async def publish_to_exchange(
confirmation_timeout: float = APP_RMQ_CONFIRMATION_TIMEOUT_SECONDS,
connection_timeout: float = APP_RMQ_CONNECTION_TIMEOUT_SECONDS,
channel: Optional[AbstractChannel] = None,
is_peer_deactivated: Callable[[], Awaitable[bool]] = _NEVER,
) -> None:
"""Publishes messages to a RabbitMQ exchange.
Expand Down Expand Up @@ -156,6 +161,7 @@ async def publish_messages(ch: AbstractChannel) -> None:
exchange_name=exchange_name,
preprocess_message=preprocess_message,
confirmation_timeout=confirmation_timeout,
is_peer_deactivated=is_peer_deactivated,
)

try:
Expand Down Expand Up @@ -299,6 +305,7 @@ async def _publish_to_exchange(
exchange_name: str,
preprocess_message: Callable[[Message], Awaitable[RmqMessage]],
confirmation_timeout: float,
is_peer_deactivated: Callable[[], Awaitable[bool]],
) -> None:
exchange = await channel.get_exchange(exchange_name, ensure=False)
deliveries: deque[_Delivery] = deque()
Expand Down Expand Up @@ -343,6 +350,9 @@ async def deliver_message(message: Message) -> None:

async def publish_messages() -> None:
while message := await recv_queue.get():
if await is_peer_deactivated():
raise ServerError("The peer has been deactivated.")

delivery = _Delivery(message.id)
mark_as_confirmed = partial(on_confirmation, delivery)

Expand Down
20 changes: 20 additions & 0 deletions swpt_stomp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

import logging
import sys
import time
import random
import os
import os.path
import asyncio
Expand Down Expand Up @@ -69,6 +71,9 @@
APP_MAX_CONNECTIONS_PER_PEER = int(
os.environ.get("APP_MAX_CONNECTIONS_PER_PEER", "10")
)
APP_PEERS_CHECK_SECONDS = float(
os.environ.get("APP_PEERS_CHECK_SECONDS", "3600")
)

_EXCHANGE_NAMES = {
NodeType.AA: "accounts_in",
Expand Down Expand Up @@ -129,6 +134,20 @@ async def publish(transport: asyncio.Transport) -> None:
)
raise
else:
next_peer_check_at = (
time.time() + random.random() * APP_PEERS_CHECK_SECONDS
)

async def is_peer_deactivated() -> bool:
nonlocal next_peer_check_at
now = time.time()
if now < next_peer_check_at:
return False
else: # pragma: no cover
next_peer_check_at = now + APP_PEERS_CHECK_SECONDS
peer_data = await db.get_peer_data(peer_serial_number)
return peer_data is None

with _allowed_peer_connection(peer_data.node_id):
await publish_to_exchange(
send_queue,
Expand All @@ -139,6 +158,7 @@ async def publish(transport: asyncio.Transport) -> None:
preprocess_message, owner_node_data, peer_data
),
channel=channel,
is_peer_deactivated=is_peer_deactivated,
)

return StompServer(
Expand Down
51 changes: 51 additions & 0 deletions tests/test_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,54 @@ async def preprocess_message(m):
assert isinstance(m, ServerError)
assert m.error_message == "Test error"
assert send_queue.empty()


@pytest.mark.asyncio
async def test_publish_deactivated_peer(rmq_url):
await ready_queue(rmq_url)
loop = asyncio.get_running_loop()
send_queue = asyncio.Queue(5)
recv_queue = WatermarkQueue(5)

message = Message(
id="1",
type="TestMessage",
body=bytearray(b"Test message"),
content_type="text/plain",
)
await recv_queue.put(message)

async def preprocess_message(m):
return RmqMessage(
id=m.id,
body=bytes(m.body),
headers={
"message-type": m.type,
"debtor-id": 1,
"creditor-id": 2,
"coordinator-id": 3,
},
type=m.type,
content_type=m.content_type,
routing_key="test_stomp",
)

async def always() -> bool:
return True

publish_task = loop.create_task(
publish_to_exchange(
send_queue,
recv_queue,
url=rmq_url,
exchange_name="",
preprocess_message=preprocess_message,
is_peer_deactivated=always,
)
)

await asyncio.wait_for(publish_task, 10.0)
m = await send_queue.get()
assert isinstance(m, ServerError)
assert m.error_message == "The peer has been deactivated."
assert send_queue.empty()

0 comments on commit c9cbb2c

Please sign in to comment.