@@ -247,7 +247,7 @@ def __init__(self, socket_path: Optional[str] = None, auto_reconnect: bool = Fal
247
247
self ._socket_path = socket_path
248
248
self ._auto_reconnect = auto_reconnect
249
249
self ._pubsub = _AIOPubSub (self )
250
- self ._subscriptions = 0
250
+ self ._subscriptions = set ()
251
251
self ._main_future = None
252
252
self ._reconnect_future = None
253
253
@@ -355,7 +355,7 @@ async def connect(self) -> 'Connection':
355
355
self ._sub_fd = self ._sub_socket .fileno ()
356
356
self ._loop .add_reader (self ._sub_fd , self ._message_reader )
357
357
358
- await self ._subscribe ( self ._subscriptions , force = True )
358
+ await self .subscribe ( list ( self ._subscriptions ) , force = True )
359
359
360
360
return self
361
361
@@ -419,25 +419,45 @@ async def _message(self, message_type: MessageType, payload: str = '') -> bytes:
419
419
420
420
return message
421
421
422
- async def _subscribe (self , events : Union [EventType , int ], force = False ):
422
+ async def subscribe (self , events : Union [List [Event ], List [str ]], force : bool = False ):
423
+ """Send a ``SUBSCRIBE`` command to the ipc subscription connection and
424
+ await the result. To attach event handlers, use :func:`Connection.on()
425
+ <i3ipc.aio.Connection.on()>`. Calling this is only needed if you want
426
+ to be notified when events will start coming in.
427
+
428
+ :ivar events: A list of events to subscribe to. Currently you cannot
429
+ subscribe to detailed events.
430
+ :vartype events: list(:class:`Event <i3ipc.Event>`) or list(str)
431
+ :ivar force: If ``False``, the message will not be sent if this
432
+ connection is already subscribed to the event.
433
+ :vartype force: bool
434
+ """
423
435
if not events :
424
436
return
425
437
426
- if type (events ) is int :
427
- events = EventType (events )
438
+ if type (events ) is not list :
439
+ raise TypeError ('events must be a list of events' )
440
+
441
+ subscriptions = set ()
442
+
443
+ for e in events :
444
+ e = Event (e )
445
+ if e not in Event ._subscribable_events :
446
+ correct_event = str .split (e .value , '::' )[0 ].upper ()
447
+ raise ValueError (
448
+ f'only nondetailed events are subscribable (use Event.{ correct_event } )' )
449
+ subscriptions .add (e )
428
450
429
451
if not force :
430
- new_subscriptions = EventType (self ._subscriptions ^ events . value )
431
- else :
432
- new_subscriptions = events
452
+ subscriptions = subscriptions . difference (self ._subscriptions )
453
+ if not subscriptions :
454
+ return
433
455
434
- if not new_subscriptions :
435
- return
456
+ self ._subscriptions .update (subscriptions )
457
+
458
+ payload = json .dumps ([s .value for s in subscriptions ])
436
459
437
- self ._subscriptions |= new_subscriptions .value
438
- event_list = new_subscriptions .to_list ()
439
- await self ._loop .sock_sendall (self ._sub_socket ,
440
- _pack (MessageType .SUBSCRIBE , json .dumps (event_list )))
460
+ await self ._loop .sock_sendall (self ._sub_socket , _pack (MessageType .SUBSCRIBE , payload ))
441
461
442
462
def on (self , event : Union [Event , str ], handler : Callable [['Connection' , IpcBaseEvent ], None ]):
443
463
"""Subscribe to the event and call the handler when it is emitted by
@@ -456,9 +476,8 @@ def on(self, event: Union[Event, str], handler: Callable[['Connection', IpcBaseE
456
476
if event .count ('::' ) > 0 :
457
477
[event , __ ] = event .split ('::' )
458
478
459
- event_type = EventType .from_string (event )
460
479
self ._pubsub .subscribe (event , handler )
461
- asyncio .ensure_future (self ._subscribe ( event_type ))
480
+ asyncio .ensure_future (self .subscribe ([ event ] ))
462
481
463
482
def off (self , handler : Callable [['Connection' , IpcBaseEvent ], None ]):
464
483
"""Unsubscribe the handler from being called on ipc events.
0 commit comments