Skip to content

Commit 336da87

Browse files
authored
Merge pull request #244 from vladak/reconnect_restoration
reconnect restoration
2 parents 1d8ef3f + 933b1cb commit 336da87

File tree

3 files changed

+261
-4
lines changed

3 files changed

+261
-4
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
219219
if port:
220220
self.port = port
221221

222+
self.session_id = None
223+
222224
# define client identifier
223225
if client_id:
224226
# user-defined client_id MAY allow client_id's > 23 bytes or
@@ -542,6 +544,7 @@ def _connect( # noqa: PLR0912, PLR0913, PLR0915, Too many branches, Too many ar
542544
is_ssl=self._is_ssl,
543545
ssl_context=self._ssl_context,
544546
)
547+
self.session_id = session_id
545548
self._backwards_compatible_sock = not hasattr(self._sock, "recv_into")
546549

547550
fixed_header = bytearray([0x10])
@@ -954,11 +957,18 @@ def reconnect(self, resub_topics: bool = True) -> int:
954957
"""
955958

956959
self.logger.debug("Attempting to reconnect with MQTT broker")
957-
ret = self.connect()
960+
subscribed_topics = []
961+
if self.is_connected():
962+
# disconnect() will reset subscribed topics so stash them now.
963+
if resub_topics:
964+
subscribed_topics = self._subscribed_topics.copy()
965+
self.disconnect()
966+
967+
ret = self.connect(session_id=self.session_id)
958968
self.logger.debug("Reconnected with broker")
959-
if resub_topics:
969+
970+
if resub_topics and subscribed_topics:
960971
self.logger.debug("Attempting to resubscribe to previously subscribed topics.")
961-
subscribed_topics = self._subscribed_topics.copy()
962972
self._subscribed_topics = []
963973
while subscribed_topics:
964974
feed = subscribed_topics.pop()

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
@pytest.fixture(autouse=True)
1212
def reset_connection_manager(monkeypatch):
13-
"""Reset the ConnectionManager, since it's a singlton and will hold data"""
13+
"""Reset the ConnectionManager, since it's a singleton and will hold data"""
1414
monkeypatch.setattr(
1515
"adafruit_minimqtt.adafruit_minimqtt.get_connection_manager",
1616
adafruit_connection_manager.ConnectionManager,

tests/test_reconnect.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# SPDX-FileCopyrightText: 2025 Vladimír Kotal
2+
#
3+
# SPDX-License-Identifier: Unlicense
4+
5+
"""reconnect tests"""
6+
7+
import logging
8+
import ssl
9+
import sys
10+
11+
import pytest
12+
from mocket import Mocket
13+
14+
import adafruit_minimqtt.adafruit_minimqtt as MQTT
15+
16+
if not sys.implementation.name == "circuitpython":
17+
from typing import Optional
18+
19+
from circuitpython_typing.socket import (
20+
SocketType,
21+
SSLContextType,
22+
)
23+
24+
25+
class FakeConnectionManager:
26+
"""
27+
Fake ConnectionManager class
28+
"""
29+
30+
def __init__(self, socket):
31+
self._socket = socket
32+
self.close_cnt = 0
33+
34+
def get_socket( # noqa: PLR0913, Too many arguments
35+
self,
36+
host: str,
37+
port: int,
38+
proto: str,
39+
session_id: Optional[str] = None,
40+
*,
41+
timeout: float = 1.0,
42+
is_ssl: bool = False,
43+
ssl_context: Optional[SSLContextType] = None,
44+
) -> SocketType:
45+
"""
46+
Return the specified socket.
47+
"""
48+
return self._socket
49+
50+
def close_socket(self, socket) -> None:
51+
self.close_cnt += 1
52+
53+
54+
def handle_subscribe(client, user_data, topic, qos):
55+
"""
56+
Record topics into user data.
57+
"""
58+
assert topic
59+
assert user_data["topics"] is not None
60+
assert qos == 0
61+
62+
user_data["topics"].append(topic)
63+
64+
65+
def handle_disconnect(client, user_data, zero):
66+
"""
67+
Record disconnect.
68+
"""
69+
70+
user_data["disconnect"] = True
71+
72+
73+
# The MQTT packet contents below were captured using Mosquitto client+server.
74+
testdata = [
75+
(
76+
[],
77+
bytearray(
78+
[
79+
0x20, # CONNACK
80+
0x02,
81+
0x00,
82+
0x00,
83+
0x90, # SUBACK
84+
0x03,
85+
0x00,
86+
0x01,
87+
0x00,
88+
0x20, # CONNACK
89+
0x02,
90+
0x00,
91+
0x00,
92+
0x90, # SUBACK
93+
0x03,
94+
0x00,
95+
0x02,
96+
0x00,
97+
]
98+
),
99+
),
100+
(
101+
[("foo/bar", 0)],
102+
bytearray(
103+
[
104+
0x20, # CONNACK
105+
0x02,
106+
0x00,
107+
0x00,
108+
0x90, # SUBACK
109+
0x03,
110+
0x00,
111+
0x01,
112+
0x00,
113+
0x20, # CONNACK
114+
0x02,
115+
0x00,
116+
0x00,
117+
0x90, # SUBACK
118+
0x03,
119+
0x00,
120+
0x02,
121+
0x00,
122+
]
123+
),
124+
),
125+
(
126+
[("foo/bar", 0), ("bah", 0)],
127+
bytearray(
128+
[
129+
0x20, # CONNACK
130+
0x02,
131+
0x00,
132+
0x00,
133+
0x90, # SUBACK
134+
0x03,
135+
0x00,
136+
0x01,
137+
0x00,
138+
0x00,
139+
0x20, # CONNACK
140+
0x02,
141+
0x00,
142+
0x00,
143+
0x90, # SUBACK
144+
0x03,
145+
0x00,
146+
0x02,
147+
0x00,
148+
0x90, # SUBACK
149+
0x03,
150+
0x00,
151+
0x03,
152+
0x00,
153+
]
154+
),
155+
),
156+
]
157+
158+
159+
@pytest.mark.parametrize(
160+
"topics,to_send",
161+
testdata,
162+
ids=[
163+
"no_topic",
164+
"single_topic",
165+
"multi_topic",
166+
],
167+
)
168+
def test_reconnect(topics, to_send) -> None:
169+
"""
170+
Test reconnect() handling, mainly that it performs disconnect on already connected socket.
171+
172+
Nothing will travel over the wire, it is all fake.
173+
"""
174+
logging.basicConfig()
175+
logger = logging.getLogger(__name__)
176+
logger.setLevel(logging.DEBUG)
177+
178+
host = "localhost"
179+
port = 1883
180+
181+
user_data = {"topics": [], "disconnect": False}
182+
mqtt_client = MQTT.MQTT(
183+
broker=host,
184+
port=port,
185+
ssl_context=ssl.create_default_context(),
186+
connect_retries=1,
187+
user_data=user_data,
188+
)
189+
190+
mocket = Mocket(to_send)
191+
mqtt_client._connection_manager = FakeConnectionManager(mocket)
192+
mqtt_client.connect()
193+
194+
mqtt_client.logger = logger
195+
196+
if topics:
197+
logger.info(f"subscribing to {topics}")
198+
mqtt_client.subscribe(topics)
199+
200+
logger.info("reconnecting")
201+
mqtt_client.on_subscribe = handle_subscribe
202+
mqtt_client.on_disconnect = handle_disconnect
203+
mqtt_client.reconnect()
204+
205+
assert user_data.get("disconnect") == True
206+
assert mqtt_client._connection_manager.close_cnt == 1
207+
assert set(user_data.get("topics")) == set([t[0] for t in topics])
208+
209+
210+
def test_reconnect_not_connected() -> None:
211+
"""
212+
Test reconnect() handling not connected.
213+
"""
214+
logging.basicConfig()
215+
logger = logging.getLogger(__name__)
216+
logger.setLevel(logging.DEBUG)
217+
218+
host = "localhost"
219+
port = 1883
220+
221+
user_data = {"topics": [], "disconnect": False}
222+
mqtt_client = MQTT.MQTT(
223+
broker=host,
224+
port=port,
225+
ssl_context=ssl.create_default_context(),
226+
connect_retries=1,
227+
user_data=user_data,
228+
)
229+
230+
mocket = Mocket(
231+
bytearray(
232+
[
233+
0x20, # CONNACK
234+
0x02,
235+
0x00,
236+
0x00,
237+
]
238+
)
239+
)
240+
mqtt_client._connection_manager = FakeConnectionManager(mocket)
241+
242+
mqtt_client.logger = logger
243+
mqtt_client.on_disconnect = handle_disconnect
244+
mqtt_client.reconnect()
245+
246+
assert user_data.get("disconnect") == False
247+
assert mqtt_client._connection_manager.close_cnt == 0

0 commit comments

Comments
 (0)