Skip to content

Commit 3ea7c63

Browse files
authored
Merge pull request #138 from opentensor/fix/thewhaleking/async-websocket-stability
async websocket stability
2 parents 798f1b3 + 0d264ba commit 3ea7c63

File tree

3 files changed

+93
-48
lines changed

3 files changed

+93
-48
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
MultiAccountId,
3333
)
3434
from websockets.asyncio.client import connect
35-
from websockets.exceptions import ConnectionClosed
35+
from websockets.exceptions import ConnectionClosed, WebSocketException
3636

3737
from async_substrate_interface.const import SS58_FORMAT
3838
from async_substrate_interface.errors import (
@@ -535,6 +535,8 @@ def __init__(
535535
self._open_subscriptions = 0
536536
self._options = options if options else {}
537537
self._log_raw_websockets = _log_raw_websockets
538+
self._is_connecting = False
539+
self._is_closing = False
538540

539541
try:
540542
now = asyncio.get_running_loop().time()
@@ -560,38 +562,63 @@ async def __aenter__(self):
560562
async def loop_time() -> float:
561563
return asyncio.get_running_loop().time()
562564

565+
async def _cancel(self):
566+
try:
567+
self._receiving_task.cancel()
568+
await self._receiving_task
569+
await self.ws.close()
570+
except (
571+
AttributeError,
572+
asyncio.CancelledError,
573+
WebSocketException,
574+
):
575+
pass
576+
except Exception as e:
577+
logger.warning(
578+
f"{e} encountered while trying to close websocket connection."
579+
)
580+
563581
async def connect(self, force=False):
564-
now = await self.loop_time()
565-
self.last_received = now
566-
self.last_sent = now
567-
if self._exit_task:
568-
self._exit_task.cancel()
569-
async with self._lock:
570-
if not self._initialized or force:
571-
try:
572-
self._receiving_task.cancel()
573-
await self._receiving_task
574-
await self.ws.close()
575-
except (AttributeError, asyncio.CancelledError):
576-
pass
577-
self.ws = await asyncio.wait_for(
578-
connect(self.ws_url, **self._options), timeout=10
579-
)
580-
self._receiving_task = asyncio.create_task(self._start_receiving())
581-
self._initialized = True
582+
self._is_connecting = True
583+
try:
584+
now = await self.loop_time()
585+
self.last_received = now
586+
self.last_sent = now
587+
if self._exit_task:
588+
self._exit_task.cancel()
589+
if not self._is_closing:
590+
if not self._initialized or force:
591+
try:
592+
await asyncio.wait_for(self._cancel(), timeout=10.0)
593+
except asyncio.TimeoutError:
594+
pass
595+
596+
self.ws = await asyncio.wait_for(
597+
connect(self.ws_url, **self._options), timeout=10.0
598+
)
599+
self._receiving_task = asyncio.get_running_loop().create_task(
600+
self._start_receiving()
601+
)
602+
self._initialized = True
603+
finally:
604+
self._is_connecting = False
582605

583606
async def __aexit__(self, exc_type, exc_val, exc_tb):
584-
async with self._lock: # TODO is this actually what I want to happen?
585-
self._in_use -= 1
586-
if self._exit_task is not None:
587-
self._exit_task.cancel()
588-
try:
589-
await self._exit_task
590-
except asyncio.CancelledError:
591-
pass
592-
if self._in_use == 0 and self.ws is not None:
593-
self._open_subscriptions = 0
594-
self._exit_task = asyncio.create_task(self._exit_with_timer())
607+
self._is_closing = True
608+
try:
609+
if not self._is_connecting:
610+
self._in_use -= 1
611+
if self._exit_task is not None:
612+
self._exit_task.cancel()
613+
try:
614+
await self._exit_task
615+
except asyncio.CancelledError:
616+
pass
617+
if self._in_use == 0 and self.ws is not None:
618+
self._open_subscriptions = 0
619+
self._exit_task = asyncio.create_task(self._exit_with_timer())
620+
finally:
621+
self._is_closing = False
595622

596623
async def _exit_with_timer(self):
597624
"""
@@ -605,16 +632,15 @@ async def _exit_with_timer(self):
605632
pass
606633

607634
async def shutdown(self):
608-
async with self._lock:
609-
try:
610-
self._receiving_task.cancel()
611-
await self._receiving_task
612-
await self.ws.close()
613-
except (AttributeError, asyncio.CancelledError):
614-
pass
615-
self.ws = None
616-
self._initialized = False
617-
self._receiving_task = None
635+
self._is_closing = True
636+
try:
637+
await asyncio.wait_for(self._cancel(), timeout=10.0)
638+
except asyncio.TimeoutError:
639+
pass
640+
self.ws = None
641+
self._initialized = False
642+
self._receiving_task = None
643+
self._is_closing = False
618644

619645
async def _recv(self) -> None:
620646
try:
@@ -624,10 +650,6 @@ async def _recv(self) -> None:
624650
raw_websocket_logger.debug(f"WEBSOCKET_RECEIVE> {recd.decode()}")
625651
response = json.loads(recd)
626652
self.last_received = await self.loop_time()
627-
async with self._lock:
628-
# note that these 'subscriptions' are all waiting sent messages which have not received
629-
# responses, and thus are not the same as RPC 'subscriptions', which are unique
630-
self._open_subscriptions -= 1
631653
if "id" in response:
632654
self._received[response["id"]] = response
633655
self._in_use_ids.remove(response["id"])
@@ -647,8 +669,7 @@ async def _start_receiving(self):
647669
except asyncio.CancelledError:
648670
pass
649671
except ConnectionClosed:
650-
async with self._lock:
651-
await self.connect(force=True)
672+
await self.connect(force=True)
652673

653674
async def send(self, payload: dict) -> int:
654675
"""
@@ -674,8 +695,7 @@ async def send(self, payload: dict) -> int:
674695
self.last_sent = await self.loop_time()
675696
return original_id
676697
except (ConnectionClosed, ssl.SSLError, EOFError):
677-
async with self._lock:
678-
await self.connect(force=True)
698+
await self.connect(force=True)
679699

680700
async def retrieve(self, item_id: int) -> Optional[dict]:
681701
"""
@@ -710,6 +730,7 @@ def __init__(
710730
retry_timeout: float = 60.0,
711731
_mock: bool = False,
712732
_log_raw_websockets: bool = False,
733+
ws_shutdown_timer: float = 5.0,
713734
):
714735
"""
715736
The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
@@ -728,6 +749,7 @@ def __init__(
728749
retry_timeout: how long to wait since the last ping to retry the RPC request
729750
_mock: whether to use mock version of the subtensor interface
730751
_log_raw_websockets: whether to log raw websocket requests during RPC requests
752+
ws_shutdown_timer: how long after the last connection your websocket should close
731753
732754
"""
733755
self.max_retries = max_retries
@@ -744,6 +766,7 @@ def __init__(
744766
"max_size": self.ws_max_size,
745767
"write_limit": 2**16,
746768
},
769+
shutdown_timer=ws_shutdown_timer,
747770
)
748771
else:
749772
self.ws = AsyncMock(spec=Websocket)

async_substrate_interface/substrate_addons.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def __init__(
264264
_mock: bool = False,
265265
_log_raw_websockets: bool = False,
266266
archive_nodes: Optional[list[str]] = None,
267+
ws_shutdown_timer: float = 5.0,
267268
):
268269
fallback_chains = fallback_chains or []
269270
archive_nodes = archive_nodes or []
@@ -291,6 +292,7 @@ def __init__(
291292
retry_timeout=retry_timeout,
292293
max_retries=max_retries,
293294
_log_raw_websockets=_log_raw_websockets,
295+
ws_shutdown_timer=ws_shutdown_timer,
294296
)
295297
self._original_methods = {
296298
method: getattr(self, method) for method in RETRY_METHODS

tests/unit_tests/asyncio_/test_substrate_interface.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from unittest.mock import AsyncMock, MagicMock
23

34
import pytest
@@ -91,3 +92,22 @@ async def test_runtime_call(monkeypatch):
9192
substrate.rpc_request.assert_any_call(
9293
"state_call", ["SubstrateApi_SubstrateMethod", "", None]
9394
)
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_websocket_shutdown_timer():
99+
# using default ws shutdown timer of 5.0 seconds
100+
async with AsyncSubstrateInterface("wss://lite.sub.latent.to:443") as substrate:
101+
await substrate.get_chain_head()
102+
await asyncio.sleep(6)
103+
assert (
104+
substrate.ws._initialized is False
105+
) # connection should have closed automatically
106+
107+
# using custom ws shutdown timer of 10.0 seconds
108+
async with AsyncSubstrateInterface(
109+
"wss://lite.sub.latent.to:443", ws_shutdown_timer=10.0
110+
) as substrate:
111+
await substrate.get_chain_head()
112+
await asyncio.sleep(6) # same sleep time as before
113+
assert substrate.ws._initialized is True # connection should still be open

0 commit comments

Comments
 (0)