32
32
MultiAccountId ,
33
33
)
34
34
from websockets .asyncio .client import connect
35
- from websockets .exceptions import ConnectionClosed
35
+ from websockets .exceptions import ConnectionClosed , WebSocketException
36
36
37
37
from async_substrate_interface .const import SS58_FORMAT
38
38
from async_substrate_interface .errors import (
@@ -535,6 +535,8 @@ def __init__(
535
535
self ._open_subscriptions = 0
536
536
self ._options = options if options else {}
537
537
self ._log_raw_websockets = _log_raw_websockets
538
+ self ._is_connecting = False
539
+ self ._is_closing = False
538
540
539
541
try :
540
542
now = asyncio .get_running_loop ().time ()
@@ -560,38 +562,63 @@ async def __aenter__(self):
560
562
async def loop_time () -> float :
561
563
return asyncio .get_running_loop ().time ()
562
564
565
+ async def _cancel (self ):
566
+ try :
567
+ self ._receiving_task .cancel ()
568
+ await self ._receiving_task
569
+ await self .ws .close ()
570
+ except (
571
+ AttributeError ,
572
+ asyncio .CancelledError ,
573
+ WebSocketException ,
574
+ ):
575
+ pass
576
+ except Exception as e :
577
+ logger .warning (
578
+ f"{ e } encountered while trying to close websocket connection."
579
+ )
580
+
563
581
async def connect (self , force = False ):
564
- now = await self .loop_time ()
565
- self .last_received = now
566
- self .last_sent = now
567
- if self ._exit_task :
568
- self ._exit_task .cancel ()
569
- async with self ._lock :
570
- if not self ._initialized or force :
571
- try :
572
- self ._receiving_task .cancel ()
573
- await self ._receiving_task
574
- await self .ws .close ()
575
- except (AttributeError , asyncio .CancelledError ):
576
- pass
577
- self .ws = await asyncio .wait_for (
578
- connect (self .ws_url , ** self ._options ), timeout = 10
579
- )
580
- self ._receiving_task = asyncio .create_task (self ._start_receiving ())
581
- self ._initialized = True
582
+ self ._is_connecting = True
583
+ try :
584
+ now = await self .loop_time ()
585
+ self .last_received = now
586
+ self .last_sent = now
587
+ if self ._exit_task :
588
+ self ._exit_task .cancel ()
589
+ if not self ._is_closing :
590
+ if not self ._initialized or force :
591
+ try :
592
+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
593
+ except asyncio .TimeoutError :
594
+ pass
595
+
596
+ self .ws = await asyncio .wait_for (
597
+ connect (self .ws_url , ** self ._options ), timeout = 10.0
598
+ )
599
+ self ._receiving_task = asyncio .get_running_loop ().create_task (
600
+ self ._start_receiving ()
601
+ )
602
+ self ._initialized = True
603
+ finally :
604
+ self ._is_connecting = False
582
605
583
606
async def __aexit__ (self , exc_type , exc_val , exc_tb ):
584
- async with self ._lock : # TODO is this actually what I want to happen?
585
- self ._in_use -= 1
586
- if self ._exit_task is not None :
587
- self ._exit_task .cancel ()
588
- try :
589
- await self ._exit_task
590
- except asyncio .CancelledError :
591
- pass
592
- if self ._in_use == 0 and self .ws is not None :
593
- self ._open_subscriptions = 0
594
- self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
607
+ self ._is_closing = True
608
+ try :
609
+ if not self ._is_connecting :
610
+ self ._in_use -= 1
611
+ if self ._exit_task is not None :
612
+ self ._exit_task .cancel ()
613
+ try :
614
+ await self ._exit_task
615
+ except asyncio .CancelledError :
616
+ pass
617
+ if self ._in_use == 0 and self .ws is not None :
618
+ self ._open_subscriptions = 0
619
+ self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
620
+ finally :
621
+ self ._is_closing = False
595
622
596
623
async def _exit_with_timer (self ):
597
624
"""
@@ -605,16 +632,15 @@ async def _exit_with_timer(self):
605
632
pass
606
633
607
634
async def shutdown (self ):
608
- async with self ._lock :
609
- try :
610
- self ._receiving_task .cancel ()
611
- await self ._receiving_task
612
- await self .ws .close ()
613
- except (AttributeError , asyncio .CancelledError ):
614
- pass
615
- self .ws = None
616
- self ._initialized = False
617
- self ._receiving_task = None
635
+ self ._is_closing = True
636
+ try :
637
+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
638
+ except asyncio .TimeoutError :
639
+ pass
640
+ self .ws = None
641
+ self ._initialized = False
642
+ self ._receiving_task = None
643
+ self ._is_closing = False
618
644
619
645
async def _recv (self ) -> None :
620
646
try :
@@ -624,10 +650,6 @@ async def _recv(self) -> None:
624
650
raw_websocket_logger .debug (f"WEBSOCKET_RECEIVE> { recd .decode ()} " )
625
651
response = json .loads (recd )
626
652
self .last_received = await self .loop_time ()
627
- async with self ._lock :
628
- # note that these 'subscriptions' are all waiting sent messages which have not received
629
- # responses, and thus are not the same as RPC 'subscriptions', which are unique
630
- self ._open_subscriptions -= 1
631
653
if "id" in response :
632
654
self ._received [response ["id" ]] = response
633
655
self ._in_use_ids .remove (response ["id" ])
@@ -647,8 +669,7 @@ async def _start_receiving(self):
647
669
except asyncio .CancelledError :
648
670
pass
649
671
except ConnectionClosed :
650
- async with self ._lock :
651
- await self .connect (force = True )
672
+ await self .connect (force = True )
652
673
653
674
async def send (self , payload : dict ) -> int :
654
675
"""
@@ -674,8 +695,7 @@ async def send(self, payload: dict) -> int:
674
695
self .last_sent = await self .loop_time ()
675
696
return original_id
676
697
except (ConnectionClosed , ssl .SSLError , EOFError ):
677
- async with self ._lock :
678
- await self .connect (force = True )
698
+ await self .connect (force = True )
679
699
680
700
async def retrieve (self , item_id : int ) -> Optional [dict ]:
681
701
"""
@@ -710,6 +730,7 @@ def __init__(
710
730
retry_timeout : float = 60.0 ,
711
731
_mock : bool = False ,
712
732
_log_raw_websockets : bool = False ,
733
+ ws_shutdown_timer : float = 5.0 ,
713
734
):
714
735
"""
715
736
The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
@@ -728,6 +749,7 @@ def __init__(
728
749
retry_timeout: how to long wait since the last ping to retry the RPC request
729
750
_mock: whether to use mock version of the subtensor interface
730
751
_log_raw_websockets: whether to log raw websocket requests during RPC requests
752
+ ws_shutdown_timer: how long after the last connection your websocket should close
731
753
732
754
"""
733
755
self .max_retries = max_retries
@@ -744,6 +766,7 @@ def __init__(
744
766
"max_size" : self .ws_max_size ,
745
767
"write_limit" : 2 ** 16 ,
746
768
},
769
+ shutdown_timer = ws_shutdown_timer ,
747
770
)
748
771
else :
749
772
self .ws = AsyncMock (spec = Websocket )
0 commit comments