From fdd436e4fbb253a2f07fca529d04b5d0b8bad4d5 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Tue, 28 Jan 2025 19:07:35 +0100 Subject: [PATCH 1/8] reduce the use of MMQTTException use it for protocol/network/system level errors only fixes #201 --- adafruit_minimqtt/adafruit_minimqtt.py | 63 ++++++++++++++++---------- tests/test_loop.py | 8 ++-- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 75148ed4..264c8e3d 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -93,13 +93,26 @@ class MMQTTException(Exception): - """MiniMQTT Exception class.""" + """ + MiniMQTT Exception class. + + Raised for various mostly protocol or network/system level errors. + In general, the robust way to recover is to call reconnect(). + """ def __init__(self, error, code=None): super().__init__(error, code) self.code = code +class MMQTTStateError(MMQTTException): + """ + MiniMQTT invalid state error. + + Raised e.g. if a function is called in unexpected state. + """ + + class NullLogger: """Fake logger class that does not do anything""" @@ -163,7 +176,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments self._use_binary_mode = use_binary_mode if recv_timeout <= socket_timeout: - raise MMQTTException("recv_timeout must be strictly greater than socket_timeout") + raise ValueError("recv_timeout must be strictly greater than socket_timeout") self._socket_timeout = socket_timeout self._recv_timeout = recv_timeout @@ -181,7 +194,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments self._reconnect_timeout = float(0) self._reconnect_maximum_backoff = 32 if connect_retries <= 0: - raise MMQTTException("connect_retries must be positive") + raise ValueError("connect_retries must be positive") self._reconnect_attempts_max = connect_retries self.broker = broker @@ -190,7 +203,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments if ( self._password and len(password.encode("utf-8")) > MQTT_TOPIC_LENGTH_LIMIT ): # [MQTT-3.1.3.5] - raise MMQTTException("Password length is too large.") + raise ValueError("Password length is too large.") # The connection will be insecure unless is_ssl is set to True. # If the port is not specified, the security will be set based on the is_ssl parameter. @@ -286,15 +299,15 @@ def will_set( """ self.logger.debug("Setting last will properties") if self._is_connected: - raise MMQTTException("Last Will should only be called before connect().") + raise MMQTTStateError("Last Will should only be called before connect().") # check topic/msg/qos kwargs self._valid_topic(topic) if "+" in topic or "#" in topic: - raise MMQTTException("Publish topic can not contain wildcards.") + raise ValueError("Publish topic can not contain wildcards.") if msg is None: - raise MMQTTException("Message can not be None.") + raise ValueError("Message can not be None.") if isinstance(msg, (int, float)): msg = str(msg).encode("ascii") elif isinstance(msg, str): @@ -302,12 +315,11 @@ def will_set( elif isinstance(msg, bytes): pass else: - raise MMQTTException("Invalid message data type.") + raise ValueError("Invalid message data type.") if len(msg) > MQTT_MSG_MAX_SZ: - raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.") + raise ValueError(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.") self._valid_qos(qos) - assert 0 <= qos <= 1, "Quality of Service Level 2 is unsupported by this library." # fixed header. [3.3.1.2], [3.3.1.3] pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1]) @@ -390,7 +402,7 @@ def username_pw_set(self, username: str, password: Optional[str] = None) -> None """ if self._is_connected: - raise MMQTTException("This method must be called before connect().") + raise MMQTTStateError("This method must be called before connect().") self._username = username if password is not None: self._password = password @@ -670,10 +682,10 @@ def publish( # noqa: PLR0912, Too many branches self._connected() self._valid_topic(topic) if "+" in topic or "#" in topic: - raise MMQTTException("Publish topic can not contain wildcards.") + raise ValueError("Publish topic can not contain wildcards.") # check msg/qos kwargs if msg is None: - raise MMQTTException("Message can not be None.") + raise ValueError("Message can not be None.") if isinstance(msg, (int, float)): msg = str(msg).encode("ascii") elif isinstance(msg, str): @@ -681,10 +693,11 @@ def publish( # noqa: PLR0912, Too many branches elif isinstance(msg, bytes): pass else: - raise MMQTTException("Invalid message data type.") + raise ValueError("Invalid message data type.") if len(msg) > MQTT_MSG_MAX_SZ: - raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.") - assert 0 <= qos <= 1, "Quality of Service Level 2 is unsupported by this library." + raise ValueError(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.") + + self._valid_qos(qos) # fixed header. [3.3.1.2], [3.3.1.3] pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1]) @@ -849,7 +862,7 @@ def unsubscribe( # noqa: PLR0912, Too many branches topics.append(t) for t in topics: if t not in self._subscribed_topics: - raise MMQTTException("Topic must be subscribed to before attempting unsubscribe.") + raise MMQTTStateError("Topic must be subscribed to before attempting unsubscribe.") # Assemble packet self.logger.debug("Sending UNSUBSCRIBE to broker...") fixed_header = bytearray([MQTT_UNSUB]) @@ -959,7 +972,7 @@ def loop(self, timeout: float = 1.0) -> Optional[list[int]]: """ if timeout < self._socket_timeout: - raise MMQTTException( + raise ValueError( f"loop timeout ({timeout}) must be >= " + f"socket timeout ({self._socket_timeout}))" ) @@ -1153,13 +1166,13 @@ def _valid_topic(topic: str) -> None: """ if topic is None: - raise MMQTTException("Topic may not be NoneType") + raise ValueError("Topic may not be NoneType") # [MQTT-4.7.3-1] if not topic: - raise MMQTTException("Topic may not be empty.") + raise ValueError("Topic may not be empty.") # [MQTT-4.7.3-3] if len(topic.encode("utf-8")) > MQTT_TOPIC_LENGTH_LIMIT: - raise MMQTTException("Topic length is too large.") + raise ValueError(f"Encoded topic length is larger than {MQTT_TOPIC_LENGTH_LIMIT}") @staticmethod def _valid_qos(qos_level: int) -> None: @@ -1170,16 +1183,16 @@ def _valid_qos(qos_level: int) -> None: """ if isinstance(qos_level, int): if qos_level < 0 or qos_level > 2: - raise MMQTTException("QoS must be between 1 and 2.") + raise NotImplementedError("QoS must be between 1 and 2.") else: - raise MMQTTException("QoS must be an integer.") + raise ValueError("QoS must be an integer.") def _connected(self) -> None: """Returns MQTT client session status as True if connected, raises - a `MMQTTException` if `False`. + a `MMQTTStateError exception` if `False`. """ if not self.is_connected(): - raise MMQTTException("MiniMQTT is not connected") + raise MMQTTStateError("MiniMQTT is not connected") def is_connected(self) -> bool: """Returns MQTT client session status as True if connected, False diff --git a/tests/test_loop.py b/tests/test_loop.py index 834a0d4f..f64dd18c 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -155,7 +155,7 @@ def test_loop_basic(self) -> None: def test_loop_timeout_vs_socket_timeout(self): """ - loop() should throw MMQTTException if the timeout argument + loop() should throw ValueError if the timeout argument is bigger than the socket timeout. """ mqtt_client = MQTT.MQTT( @@ -167,14 +167,14 @@ def test_loop_timeout_vs_socket_timeout(self): ) mqtt_client.is_connected = lambda: True - with pytest.raises(MQTT.MMQTTException) as context: + with pytest.raises(ValueError) as context: mqtt_client.loop(timeout=0.5) assert "loop timeout" in str(context) def test_loop_is_connected(self): """ - loop() should throw MMQTTException if not connected + loop() should throw MMQTTStateError if not connected """ mqtt_client = MQTT.MQTT( broker="127.0.0.1", @@ -183,7 +183,7 @@ def test_loop_is_connected(self): ssl_context=ssl.create_default_context(), ) - with pytest.raises(MQTT.MMQTTException) as context: + with pytest.raises(MQTT.MMQTTStateError) as context: mqtt_client.loop(timeout=1) assert "not connected" in str(context) From 1778c7cfd6773c6523ef6390bd973929e21bfc0a Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Sun, 9 Feb 2025 23:34:03 +0100 Subject: [PATCH 2/8] fix typo --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 93eefcbd..2b8b03ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ @pytest.fixture(autouse=True) def reset_connection_manager(monkeypatch): - """Reset the ConnectionManager, since it's a singlton and will hold data""" + """Reset the ConnectionManager, since it's a singleton and will hold data""" monkeypatch.setattr( "adafruit_minimqtt.adafruit_minimqtt.get_connection_manager", adafruit_connection_manager.ConnectionManager, From dc361342e4604b9725ed3af6a84c6d4bafeb8302 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Sun, 9 Feb 2025 23:34:37 +0100 Subject: [PATCH 3/8] disconnect on reconnect() if connected fixes #243 --- adafruit_minimqtt/adafruit_minimqtt.py | 11 +- tests/test_reconnect.py | 205 +++++++++++++++++++++++++ 2 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 tests/test_reconnect.py diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 75148ed4..1c63d5ca 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -939,11 +939,18 @@ def reconnect(self, resub_topics: bool = True) -> int: """ self.logger.debug("Attempting to reconnect with MQTT broker") + subscribed_topics = [] + if self.is_connected(): + # disconnect() will reset subscribed topics so stash them now. + if resub_topics: + subscribed_topics = self._subscribed_topics.copy() + self.disconnect() + ret = self.connect() self.logger.debug("Reconnected with broker") - if resub_topics: + + if resub_topics and subscribed_topics: self.logger.debug("Attempting to resubscribe to previously subscribed topics.") - subscribed_topics = self._subscribed_topics.copy() self._subscribed_topics = [] while subscribed_topics: feed = subscribed_topics.pop() diff --git a/tests/test_reconnect.py b/tests/test_reconnect.py new file mode 100644 index 00000000..6b049246 --- /dev/null +++ b/tests/test_reconnect.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: 2025 VladimĂ­r Kotal +# +# SPDX-License-Identifier: Unlicense + +"""reconnect tests""" + +import logging +import ssl +import sys + +import pytest +from mocket import Mocket + +import adafruit_minimqtt.adafruit_minimqtt as MQTT + +if not sys.implementation.name == "circuitpython": + from typing import Optional + + from circuitpython_typing.socket import ( + SocketType, + SSLContextType, + ) + + +class FakeConnectionManager: + """ + Fake ConnectionManager class + """ + + def __init__(self, socket): + self._socket = socket + + def get_socket( # noqa: PLR0913, Too many arguments + self, + host: str, + port: int, + proto: str, + session_id: Optional[str] = None, + *, + timeout: float = 1.0, + is_ssl: bool = False, + ssl_context: Optional[SSLContextType] = None, + ) -> SocketType: + """ + Return the specified socket. + """ + return self._socket + + def close_socket(self, socket) -> None: + pass + + +def handle_subscribe(client, user_data, topic, qos): + """ + Record topics into user data. + """ + assert topic + assert user_data["topics"] is not None + assert qos == 0 + + user_data["topics"].append(topic) + + +def handle_disconnect(client, user_data, zero): + """ + Record disconnect. + """ + + user_data["disconnect"] = True + + +# The MQTT packet contents below were captured using Mosquitto client+server. +testdata = [ + ( + [], + bytearray( + [ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x02, + 0x00, + ] + ), + ), + ( + [("foo/bar", 0)], + bytearray( + [ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x02, + 0x00, + ] + ), + ), + ( + [("foo/bar", 0), ("bah", 0)], + bytearray( + [ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + 0x00, + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x02, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x03, + 0x00, + ] + ), + ), +] + + +@pytest.mark.parametrize( + "topics,to_send", + testdata, + ids=[ + "no_topic", + "single_topic", + "multi_topic", + ], +) +def test_reconnect(topics, to_send) -> None: + """ + Test reconnect() handling, mainly that it performs disconnect on already connected socket. + + Nothing will travel over the wire, it is all fake. + """ + logging.basicConfig() + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + host = "localhost" + port = 1883 + + user_data = {"topics": [], "disconnect": False} + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + ssl_context=ssl.create_default_context(), + connect_retries=1, + user_data=user_data, + ) + + mocket = Mocket(to_send) + mqtt_client._connection_manager = FakeConnectionManager(mocket) + mqtt_client.connect() + + mqtt_client.logger = logger + + if topics: + logger.info(f"subscribing to {topics}") + mqtt_client.subscribe(topics) + + logger.info("reconnecting") + mqtt_client.on_subscribe = handle_subscribe + mqtt_client.on_disconnect = handle_disconnect + mqtt_client.reconnect() + + assert user_data.get("disconnect") == True + assert set(user_data.get("topics")) == set([t[0] for t in topics]) From 6ebbb2601bf934ab0d376c7deaca9f3ce524bf04 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Mon, 10 Feb 2025 08:02:21 +0100 Subject: [PATCH 4/8] check close count --- tests/test_reconnect.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_reconnect.py b/tests/test_reconnect.py index 6b049246..5e0a5331 100644 --- a/tests/test_reconnect.py +++ b/tests/test_reconnect.py @@ -29,6 +29,7 @@ class FakeConnectionManager: def __init__(self, socket): self._socket = socket + self.close_cnt = 0 def get_socket( # noqa: PLR0913, Too many arguments self, @@ -47,7 +48,7 @@ def get_socket( # noqa: PLR0913, Too many arguments return self._socket def close_socket(self, socket) -> None: - pass + self.close_cnt += 1 def handle_subscribe(client, user_data, topic, qos): @@ -202,4 +203,5 @@ def test_reconnect(topics, to_send) -> None: mqtt_client.reconnect() assert user_data.get("disconnect") == True + assert mqtt_client._connection_manager.close_cnt == 1 assert set(user_data.get("topics")) == set([t[0] for t in topics]) From 8f4c421c69aa412f64176c3d2034e56db0b88fd2 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Mon, 10 Feb 2025 08:10:57 +0100 Subject: [PATCH 5/8] add test for the not connected case --- tests/test_reconnect.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_reconnect.py b/tests/test_reconnect.py index 5e0a5331..bd9d3926 100644 --- a/tests/test_reconnect.py +++ b/tests/test_reconnect.py @@ -205,3 +205,43 @@ def test_reconnect(topics, to_send) -> None: assert user_data.get("disconnect") == True assert mqtt_client._connection_manager.close_cnt == 1 assert set(user_data.get("topics")) == set([t[0] for t in topics]) + + +def test_reconnect_not_connected() -> None: + """ + Test reconnect() handling not connected. + """ + logging.basicConfig() + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + host = "localhost" + port = 1883 + + user_data = {"topics": [], "disconnect": False} + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + ssl_context=ssl.create_default_context(), + connect_retries=1, + user_data=user_data, + ) + + mocket = Mocket( + bytearray( + [ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + ] + ) + ) + mqtt_client._connection_manager = FakeConnectionManager(mocket) + + mqtt_client.logger = logger + mqtt_client.on_disconnect = handle_disconnect + mqtt_client.reconnect() + + assert user_data.get("disconnect") == False + assert mqtt_client._connection_manager.close_cnt == 0 From 933b1cb66daa7b2d0e1995d329eeebea12052bd5 Mon Sep 17 00:00:00 2001 From: Vladimir Kotal Date: Mon, 10 Feb 2025 18:02:41 +0100 Subject: [PATCH 6/8] preserve sesion ID on reconnect --- adafruit_minimqtt/adafruit_minimqtt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 1c63d5ca..a8d53428 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -204,6 +204,8 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments if port: self.port = port + self.session_id = None + # define client identifier if client_id: # user-defined client_id MAY allow client_id's > 23 bytes or @@ -528,6 +530,7 @@ def _connect( # noqa: PLR0912, PLR0913, PLR0915, Too many branches, Too many ar is_ssl=self._is_ssl, ssl_context=self._ssl_context, ) + self.session_id = session_id self._backwards_compatible_sock = not hasattr(self._sock, "recv_into") fixed_header = bytearray([0x10]) @@ -946,7 +949,7 @@ def reconnect(self, resub_topics: bool = True) -> int: subscribed_topics = self._subscribed_topics.copy() self.disconnect() - ret = self.connect() + ret = self.connect(session_id=self.session_id) self.logger.debug("Reconnected with broker") if resub_topics and subscribed_topics: From 40a4fa91268f871aaae23354748177a57873ab56 Mon Sep 17 00:00:00 2001 From: Tyeth Gundry Date: Wed, 12 Mar 2025 19:04:03 +0000 Subject: [PATCH 7/8] ruff format reconnect tests --- tests/test_reconnect.py | 152 +++++++++++++++++++--------------------- 1 file changed, 72 insertions(+), 80 deletions(-) diff --git a/tests/test_reconnect.py b/tests/test_reconnect.py index bd9d3926..52b8c76f 100644 --- a/tests/test_reconnect.py +++ b/tests/test_reconnect.py @@ -74,84 +74,78 @@ def handle_disconnect(client, user_data, zero): testdata = [ ( [], - bytearray( - [ - 0x20, # CONNACK - 0x02, - 0x00, - 0x00, - 0x90, # SUBACK - 0x03, - 0x00, - 0x01, - 0x00, - 0x20, # CONNACK - 0x02, - 0x00, - 0x00, - 0x90, # SUBACK - 0x03, - 0x00, - 0x02, - 0x00, - ] - ), + bytearray([ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x02, + 0x00, + ]), ), ( [("foo/bar", 0)], - bytearray( - [ - 0x20, # CONNACK - 0x02, - 0x00, - 0x00, - 0x90, # SUBACK - 0x03, - 0x00, - 0x01, - 0x00, - 0x20, # CONNACK - 0x02, - 0x00, - 0x00, - 0x90, # SUBACK - 0x03, - 0x00, - 0x02, - 0x00, - ] - ), + bytearray([ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x02, + 0x00, + ]), ), ( [("foo/bar", 0), ("bah", 0)], - bytearray( - [ - 0x20, # CONNACK - 0x02, - 0x00, - 0x00, - 0x90, # SUBACK - 0x03, - 0x00, - 0x01, - 0x00, - 0x00, - 0x20, # CONNACK - 0x02, - 0x00, - 0x00, - 0x90, # SUBACK - 0x03, - 0x00, - 0x02, - 0x00, - 0x90, # SUBACK - 0x03, - 0x00, - 0x03, - 0x00, - ] - ), + bytearray([ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x01, + 0x00, + 0x00, + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x02, + 0x00, + 0x90, # SUBACK + 0x03, + 0x00, + 0x03, + 0x00, + ]), ), ] @@ -228,14 +222,12 @@ def test_reconnect_not_connected() -> None: ) mocket = Mocket( - bytearray( - [ - 0x20, # CONNACK - 0x02, - 0x00, - 0x00, - ] - ) + bytearray([ + 0x20, # CONNACK + 0x02, + 0x00, + 0x00, + ]) ) mqtt_client._connection_manager = FakeConnectionManager(mocket) From a0f49a46a1e09de14df5e86769b5376aee6b3442 Mon Sep 17 00:00:00 2001 From: Neradoc Date: Sun, 13 Apr 2025 21:45:52 +0200 Subject: [PATCH 8/8] fix missing parameter in MMQTTException --- adafruit_minimqtt/adafruit_minimqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 953718ba..2a1d6000 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -1038,7 +1038,7 @@ def _wait_for_msg( # noqa: PLR0912, Too many branches if error.errno in (errno.ETIMEDOUT, errno.EAGAIN): # raised by a socket timeout if 0 bytes were present return None - raise MMQTTException from error + raise MMQTTException("Unexpected error while waiting for messages") from error if res in [None, b""]: # If we get here, it means that there is nothing to be received