From 5dc0f0a05d5532c987038768e034953cbd3a71ce Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Fri, 5 Sep 2025 00:27:53 +0200 Subject: [PATCH 1/6] Improve typing in v2 QuirkBuilder (#1660) --- zigpy/quirks/v2/__init__.py | 55 +++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/zigpy/quirks/v2/__init__.py b/zigpy/quirks/v2/__init__.py index e680022f7..0f144c38b 100644 --- a/zigpy/quirks/v2/__init__.py +++ b/zigpy/quirks/v2/__init__.py @@ -14,6 +14,7 @@ import attrs from frozendict import frozendict +from typing_extensions import Self from zigpy.const import ( SIG_ENDPOINTS, @@ -632,7 +633,7 @@ def __init__( UNBUILT_QUIRK_BUILDERS.append(self) - def _add_entity_metadata(self, entity_metadata: EntityMetadata) -> QuirkBuilder: + def _add_entity_metadata(self, entity_metadata: EntityMetadata) -> Self: """Register new entity metadata and validate config.""" if entity_metadata.primary and any( entity.primary for entity in self.entity_metadata @@ -642,7 +643,7 @@ def _add_entity_metadata(self, entity_metadata: EntityMetadata) -> QuirkBuilder: self.entity_metadata.append(entity_metadata) return self - def applies_to(self, manufacturer: str, model: str) -> QuirkBuilder: + def applies_to(self, manufacturer: str, model: str) -> Self: """Register this quirks v2 entry for the specified manufacturer and model.""" self.manufacturer_model_metadata.append( ManufacturerModelMetadata(manufacturer=manufacturer, model=model) @@ -652,7 +653,7 @@ def applies_to(self, manufacturer: str, model: str) -> QuirkBuilder: # backward compatibility also_applies_to = applies_to - def filter(self, filter_function: FilterType) -> QuirkBuilder: + def filter(self, filter_function: FilterType) -> Self: """Add a filter and returns self. The filter function should take a single argument, a zigpy.device.Device @@ -669,7 +670,7 @@ def firmware_version_filter( min_version: int | None = None, max_version: int | None = None, allow_missing: bool = True, - ) -> QuirkBuilder: + ) -> Self: """Add a firmware version filter and returns self. The min_version and max_version are integers representing the firmware version, @@ -683,7 +684,7 @@ def firmware_version_filter( ) return self - def device_class(self, custom_device_class: type[CustomDeviceV2]) -> QuirkBuilder: + def device_class(self, custom_device_class: type[CustomDeviceV2]) -> Self: """Set the custom device class to be used in this quirk and returns self. The custom device class must be a subclass of CustomDeviceV2. @@ -694,7 +695,7 @@ def device_class(self, custom_device_class: type[CustomDeviceV2]) -> QuirkBuilde self.custom_device_class = custom_device_class return self - def node_descriptor(self, node_descriptor: NodeDescriptor) -> QuirkBuilder: + def node_descriptor(self, node_descriptor: NodeDescriptor) -> Self: """Set the node descriptor and returns self. The node descriptor must be a NodeDescriptor instance and it will be used @@ -703,7 +704,7 @@ def node_descriptor(self, node_descriptor: NodeDescriptor) -> QuirkBuilder: self.device_node_descriptor = node_descriptor.freeze() return self - def skip_configuration(self, skip_configuration: bool = True) -> QuirkBuilder: + def skip_configuration(self, skip_configuration: bool = True) -> Self: """Set the skip_configuration and returns self. 
If skip_configuration is True, reporting configuration will not be @@ -718,7 +719,7 @@ def adds( cluster_type: ClusterType = ClusterType.Server, endpoint_id: int = 1, constant_attributes: dict[ZCLAttributeDef, Any] | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an AddsMetadata entry and returns self. This method allows adding a cluster to a device when the quirk is applied. @@ -745,7 +746,7 @@ def removes( cluster_id: int, cluster_type: ClusterType = ClusterType.Server, endpoint_id: int = 1, - ) -> QuirkBuilder: + ) -> Self: """Add a RemovesMetadata entry and returns self. This method allows removing a cluster from a device when the quirk is applied. @@ -764,7 +765,7 @@ def replaces( cluster_id: int | None = None, cluster_type: ClusterType = ClusterType.Server, endpoint_id: int = 1, - ) -> QuirkBuilder: + ) -> Self: """Add a ReplacesMetadata entry and returns self. This method allows replacing a cluster on a device when the quirk is applied. @@ -797,7 +798,7 @@ def replace_cluster_occurrences( replacement_cluster_class: type[Cluster | CustomCluster], replace_server_instances: bool = True, replace_client_instances: bool = True, - ) -> QuirkBuilder: + ) -> Self: """Add a ReplaceClusterOccurrencesMetadata entry and returns self. This method allows replacing a cluster on a device across all endpoints @@ -829,7 +830,7 @@ def adds_endpoint( endpoint_id: int, profile_id: int = zigpy.profiles.zha.PROFILE_ID, device_type: int = 0xFF, - ) -> QuirkBuilder: + ) -> Self: """Add an AddsEndpointMetadata entry and return self.""" add = AddsEndpointMetadata( endpoint_id=endpoint_id, profile_id=profile_id, device_type=device_type @@ -837,7 +838,7 @@ def adds_endpoint( self.adds_endpoint_metadata.append(add) return self - def removes_endpoint(self, endpoint_id: int) -> QuirkBuilder: + def removes_endpoint(self, endpoint_id: int) -> Self: """Add a RemovesEndpointMetadata entry and return self.""" remove = RemovesEndpointMetadata(endpoint_id=endpoint_id) self.removes_endpoint_metadata.append(remove) @@ -848,7 +849,7 @@ def replaces_endpoint( endpoint_id: int, profile_id: int = zigpy.profiles.zha.PROFILE_ID, device_type: int = 0xFF, - ) -> QuirkBuilder: + ) -> Self: """Add a ReplacesEndpointMetadata entry and return self.""" replace = ReplacesEndpointMetadata( endpoint_id=endpoint_id, profile_id=profile_id, device_type=device_type @@ -872,7 +873,7 @@ def enum( translation_key: str | None = None, fallback_name: str | None = None, primary: bool | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an EntityMetadata containing ZCLEnumMetadata and return self. This method allows exposing an enum based entity in Home Assistant. @@ -918,7 +919,7 @@ def sensor( translation_key: str | None = None, fallback_name: str | None = None, primary: bool | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an EntityMetadata containing ZCLSensorMetadata and return self. This method allows exposing a sensor entity in Home Assistant. @@ -968,7 +969,7 @@ def switch( translation_key: str | None = None, fallback_name: str | None = None, primary: bool | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an EntityMetadata containing SwitchMetadata and return self. This method allows exposing a switch entity in Home Assistant. @@ -1017,7 +1018,7 @@ def number( translation_key: str | None = None, fallback_name: str | None = None, primary: bool | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an EntityMetadata containing NumberMetadata and return self. This method allows exposing a number entity in Home Assistant. 
@@ -1064,7 +1065,7 @@ def binary_sensor( translation_key: str | None = None, fallback_name: str | None = None, primary: bool | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an EntityMetadata containing BinarySensorMetadata and return self. This method allows exposing a binary sensor entity in Home Assistant. @@ -1104,7 +1105,7 @@ def write_attr_button( translation_key: str | None = None, fallback_name: str | None = None, primary: bool | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an EntityMetadata containing WriteAttributeButtonMetadata and return self. This method allows exposing a button entity in Home Assistant that writes @@ -1143,7 +1144,7 @@ def command_button( translation_key: str | None = None, fallback_name: str | None = None, primary: bool | None = None, - ) -> QuirkBuilder: + ) -> Self: """Add an EntityMetadata containing ZCLCommandButtonMetadata and return self. This method allows exposing a button entity in Home Assistant that executes @@ -1170,19 +1171,19 @@ def command_button( def device_automation_triggers( self, device_automation_triggers: dict[tuple[str, str], dict[str, str]] - ) -> QuirkBuilder: + ) -> Self: """Add device automation triggers and returns self.""" self.device_automation_triggers_metadata.update(device_automation_triggers) return self - def friendly_name(self, *, model: str, manufacturer: str) -> QuirkBuilder: + def friendly_name(self, *, model: str, manufacturer: str) -> Self: """Renames the device.""" self.friendly_name_metadata = FriendlyNameMetadata( model=model, manufacturer=manufacturer ) return self - def device_alert(self, *, level: DeviceAlertLevel, message: str) -> QuirkBuilder: + def device_alert(self, *, level: DeviceAlertLevel, message: str) -> Self: """Adds a device alert.""" self.device_alerts.append(DeviceAlertMetadata(level=level, message=message)) return self @@ -1195,7 +1196,7 @@ def prevent_default_entity_creation( cluster_type: ClusterType | None = None, unique_id_suffix: str | None = None, function: Callable[[Any], bool] | None = None, - ) -> QuirkBuilder: + ) -> Self: """Do not create default entities.""" if cluster_id is not None and cluster_type is None: cluster_type = ClusterType.Server @@ -1229,7 +1230,7 @@ def change_entity_metadata( new_entity_category: EntityType | None = None, new_entity_registry_enabled_default: bool | None = None, new_fallback_name: str | None = None, - ) -> QuirkBuilder: + ) -> Self: """Change entity metadata for matching entities.""" if cluster_id is not None and cluster_type is None: cluster_type = ClusterType.Server @@ -1294,7 +1295,7 @@ def add_to_registry(self) -> QuirksV2RegistryEntry: return quirk - def clone(self, omit_man_model_data=True) -> QuirkBuilder: + def clone(self, omit_man_model_data=True) -> Self: """Clone this QuirkBuilder potentially omitting manufacturer and model data.""" new_builder = deepcopy(self) new_builder.registry = self.registry From 88f0ae110a29500cf46940eb715c37d773276e01 Mon Sep 17 00:00:00 2001 From: janeswingler <165973604+janeswingler@users.noreply.github.com> Date: Mon, 8 Sep 2025 13:17:25 -0700 Subject: [PATCH 2/6] Add packet reception API (#1643) * Add register_packet_callback() and self._packet_callbacks * Add notify_packet_callbacks() * Add generic packet reception API - Add register_packet_callback() method for registering packet callbacks - Add notify_packet_callbacks() to dispatch packets to registered callbacks - Support both global callbacks (filter=None) and address-specific filtering - Add _packet_callbacks storage using defaultdict for 
callback management * Add generic packet reception API This implements a generic packet reception API as referenced in issue #1468. It allows for registration of callbacks for raw ZigbeePacket objects, and enables future features like TouchLink support via inter-PAN communication. - Add register_packet_callback() method for raw packet callbacks - Add notify_packet_callbacks() to dispatch packets to callbacks - Add _packet_callbacks storage using defaultdict Co-authored-by: Justin Saminathen Co-authored-by: Aaron Mundanilkunathil * Address review feedback for packet reception API Co-authored-by: Jane Swingler Co-authored-by: Aaron Mundanilkunathil * Added unit tests for packet reception API Co-authored-by: Jane Swingler Co-authored-by: Aaron Mundanilkunathil * Replace make_packet with base_packet fixture utilizing .replace(), and use mock_calls assertions Co-authored-by: Jane Swingler Co-authored-by: Aaron Mundanilkunathil * Fix pre-commit formatting issues Co-authored-by: Justin Saminathen Co-authored-by: Aaron Mundanilkunathil * tests: add address + global exception callback tests for packet reception API Co-authored-by: Justin Saminathen Co-authored-by: Aaron Mundanilkunathil --------- Co-authored-by: Justin Saminathen Co-authored-by: Aaron Mundanilkunathil Co-authored-by: A-Mundanilkunathil <119900566+A-Mundanilkunathil@users.noreply.github.com> --- .gitignore | 5 +- tests/test_packet_callbacks.py | 116 +++++++++++++++++++++++++++++++++ zigpy/application.py | 53 +++++++++++++++ 3 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 tests/test_packet_callbacks.py diff --git a/.gitignore b/.gitignore index b0f484a7e..d547c08f1 100644 --- a/.gitignore +++ b/.gitignore @@ -75,4 +75,7 @@ ENV/ .DS_Store # Don't keep track of downloaded OTA files -tests/ota/files/external/dl \ No newline at end of file +tests/ota/files/external/dl + +# VS Code workspace files +*.code-workspace \ No newline at end of file diff --git a/tests/test_packet_callbacks.py b/tests/test_packet_callbacks.py new file mode 100644 index 000000000..80c034158 --- /dev/null +++ b/tests/test_packet_callbacks.py @@ -0,0 +1,116 @@ +import pytest + +import zigpy.types as t + +from .async_mock import MagicMock, call +from .conftest import make_ieee + + +@pytest.fixture +def base_packet(): + """Base ZigbeePacket to clone with .replace().""" + return t.ZigbeePacket( + src=t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=0x1234), + src_ep=1, + dst=t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=0x0000), + dst_ep=1, + tsn=123, + profile_id=0x0104, + cluster_id=0x0006, + data=t.SerializableBytes(b"test"), + lqi=255, + rssi=-30, + ) + + +@pytest.mark.parametrize( + "filter_address", + [ + None, + t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=0x1234), + t.AddrModeAddress(addr_mode=t.AddrMode.IEEE, address=make_ieee()), + ], +) +async def test_packet_callback_register_cancel(app, filter_address): + """Register and cancel packet callback.""" + cb = MagicMock() + cancel = app.register_packet_callback(filter_address, cb) + assert cb in app._packet_callbacks[filter_address] + cancel() + assert cb not in app._packet_callbacks[filter_address] + cancel() + + +@pytest.mark.parametrize( + ("src_address", "should_trigger"), + [ + (t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=0x1234), True), + (t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=0x5678), False), + ], +) +async def test_packet_callback_address_filter( + app, base_packet, src_address, should_trigger +): + """Source address match.""" + filt = 
t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=0x1234) + cb = MagicMock() + app.register_packet_callback(filt, cb) + pkt = base_packet.replace(src=src_address) + app.notify_packet_callbacks(pkt) + if should_trigger: + assert cb.mock_calls == [call(pkt)] + else: + assert cb.mock_calls == [] + + +async def test_packet_callback_addr_mode_mismatch(app, base_packet): + """Address mode mismatch.""" + ieee = make_ieee() + filt = t.AddrModeAddress(addr_mode=t.AddrMode.IEEE, address=ieee) + cb = MagicMock() + app.register_packet_callback(filt, cb) + app.notify_packet_callbacks(base_packet) + assert cb.mock_calls == [] + + +async def test_packet_callback_multiple_same_filter(app, base_packet): + """Multiple callbacks, cancel one.""" + addr = t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=0x9ABC) + cb1 = MagicMock() + cb2 = MagicMock() + cancel1 = app.register_packet_callback(addr, cb1) + app.register_packet_callback(addr, cb2) + pkt1 = base_packet.replace(src=addr, tsn=200) + app.notify_packet_callbacks(pkt1) + assert cb1.mock_calls == [call(pkt1)] + assert cb2.mock_calls == [call(pkt1)] + cancel1() + pkt2 = base_packet.replace(src=addr, tsn=201) + app.notify_packet_callbacks(pkt2) + assert cb1.mock_calls == [call(pkt1)] + assert cb2.mock_calls == [call(pkt1), call(pkt2)] + + +async def test_packet_callback_exception_global(app, base_packet, caplog): + """Exception isolation for global callbacks.""" + failing = MagicMock(side_effect=ValueError("boom")) + ok = MagicMock() + app.register_packet_callback(None, failing) + app.register_packet_callback(None, ok) + app.notify_packet_callbacks(base_packet) + assert ok.mock_calls == [call(base_packet)] + assert any("global packet callback" in r.message.lower() for r in caplog.records) + + +async def test_packet_callback_exception_address(app, base_packet, caplog): + """Exception isolation for address-specific callbacks.""" + addr = base_packet.src + failing = MagicMock(side_effect=ValueError("boom")) + ok = MagicMock() + app.register_packet_callback(addr, failing) + app.register_packet_callback(addr, ok) + app.notify_packet_callbacks(base_packet) + assert ok.mock_calls == [call(base_packet)] + assert any( + "packet callback for address" in r.message.lower() for r in caplog.records + ) diff --git a/zigpy/application.py b/zigpy/application.py index 1783940c3..bf3ffb4d3 100644 --- a/zigpy/application.py +++ b/zigpy/application.py @@ -92,6 +92,11 @@ def __init__(self, config: dict) -> None: collections.deque[zigpy.listeners.BaseRequestListener], ] = collections.defaultdict(lambda: collections.deque([])) + # Add callback storage + self._packet_callbacks: collections.defaultdict[ + t.AddrModeAddress | None, list[typing.Callable[[t.ZigbeePacket], None]] + ] = collections.defaultdict(list) + # Context variable for request priority context manager self._packet_priority_var = contextvars.ContextVar( "request_priority", default=t.PacketPriority.NORMAL @@ -1194,6 +1199,54 @@ def get_device_with_address( else: raise ValueError(f"Invalid address: {address!r}") + def register_packet_callback( + self, + filter: t.AddrModeAddress | None, + callback: typing.Callable[[t.ZigbeePacket], None], + ) -> typing.Callable[[], None]: + """Register a callback that is called when a Zigbee packet is received. + + Args: + ---- + filter: Optional address filter. If None, callback receives all packets. + If provided, only packets from this source address trigger the callback. + callback: Function to call when a matching packet is received. 
+ + Returns: + ------- + A callable that can be used to unregister the callback. + + """ + self._packet_callbacks[filter].append(callback) + + def cancel_callback() -> None: + """Remove the callback.""" + with contextlib.suppress(ValueError): + self._packet_callbacks[filter].remove(callback) + + return cancel_callback + + def notify_packet_callbacks(self, packet: t.ZigbeePacket) -> None: + """Notify registered packet callbacks about a received Zigbee packet.""" + + # Notify global callbacks (registered with None filter) + for callback in self._packet_callbacks[None]: + try: + callback(packet) + except Exception: + LOGGER.exception("Error in global packet callback: %s", callback) + + # Notify address-specific callbacks + for callback in self._packet_callbacks[packet.src]: + try: + callback(packet) + except Exception: + LOGGER.exception( + "Error in packet callback for address %s: %s", + packet.src, + callback, + ) + def register_callback_listener( self, src: zigpy.device.Device | zigpy.listeners.ANY_DEVICE, From 5d5ce82f29b08b0db330e865ad20ec8791a8eefc Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 9 Sep 2025 18:11:21 -0400 Subject: [PATCH 3/6] Add `can_write_network_settings` to ControllerApplication API (#1616) * Implement can_write_network_settings * Allow passing `None` * Revert "Allow passing `None`" This reverts commit a989c7d94f4d8abc97a9b882b8f2e0fa7aaa4293. * Add `DestructiveWriteNetworkSettings` * Enforce kwargs * Add `can_read_network_settings` * Oops * Revert "Add `can_read_network_settings`" This reverts commit 666915b114f3c1c0e5b0afa264aac61d17439ce6. --- tests/test_application.py | 7 +++++++ zigpy/application.py | 20 ++++++++++++++++++++ zigpy/exceptions.py | 8 ++++++++ 3 files changed, 35 insertions(+) diff --git a/tests/test_application.py b/tests/test_application.py index 3330199bd..e53f95ade 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1689,3 +1689,10 @@ async def task_with_priority(name: str, priority: int): call(packet.replace(data=b"high")), call(packet.replace(data=b"normal")), ] + + +async def test_can_write_network_settings(app) -> None: + # The default is True + assert await app.can_write_network_settings( + network_info=app.state.network_info, node_info=app.state.node_info + ) diff --git a/zigpy/application.py b/zigpy/application.py index bf3ffb4d3..1835816d1 100644 --- a/zigpy/application.py +++ b/zigpy/application.py @@ -1344,6 +1344,26 @@ async def permit_with_link_key( """Permit a node to join with the provided link key.""" raise NotImplementedError # pragma: no cover + async def can_write_network_settings( + self, + *, + network_info: zigpy.state.NetworkInfo, + node_info: zigpy.state.NodeInfo, + ) -> bool: + """Returns `True` if the radio can write the given network settings. + + If restoration is not possible, `CannotWriteNetworkSettings` is raised. + If restoration is possible in a destructive way (e.g. write-once tokens), + `DestructiveWriteNetworkSettings` is raised. + + Some radio firmwares do not support writing every network setting in a backup + (ZiGate cannot set the PAN ID, older EZSP can only write the EUI64 once, etc.). + Not all situations, however, are critical failures: if we are restoring a + backup where the PAN ID does not change or the EUI64 remains the same, we can + restore the backup successfully. 
+ """ + return True + @abc.abstractmethod async def write_network_info( self, diff --git a/zigpy/exceptions.py b/zigpy/exceptions.py index b7b85562d..2906a62da 100644 --- a/zigpy/exceptions.py +++ b/zigpy/exceptions.py @@ -18,6 +18,14 @@ class ControllerException(ZigbeeException): """Application controller failed in some way.""" +class CannotWriteNetworkSettings(ZigbeeException): + """The provided network settings cannot be written due to a radio limitation.""" + + +class DestructiveWriteNetworkSettings(ZigbeeException): + """The provided network settings will be written but in a destructive manner.""" + + class APIException(ZigbeeException): """Radio API failed in some way.""" From 5b3472597d1ca5d081043752ef09873afd006203 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 9 Sep 2025 18:11:31 -0400 Subject: [PATCH 4/6] Load devices by default when performing a backup (#1666) --- zigpy/backups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zigpy/backups.py b/zigpy/backups.py index 0cc95a8a0..8036e1b90 100644 --- a/zigpy/backups.py +++ b/zigpy/backups.py @@ -139,7 +139,7 @@ def from_network_state(self) -> NetworkBackup: node_info=self.app.state.node_info, ) - async def create_backup(self, *, load_devices: bool = False) -> NetworkBackup: + async def create_backup(self, *, load_devices: bool = True) -> NetworkBackup: await self.app.load_network_info(load_devices=load_devices) backup = self.from_network_state() From 89bf7509e290010235f9528b57b176a965faf6d7 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 9 Sep 2025 18:11:44 -0400 Subject: [PATCH 5/6] Implement just `EventBase` class (#1665) * Port over ZHA event module * Move it into a submodule * Reimplement counters as an eventable object * Remove unnecessary tests * Rename submodule * Support Python 3.9 * Python 3.9 compatibility * Sleeps for 3.9 * Revert "Sleeps for 3.9" This reverts commit 1506e98fb54d7f6634c0d6a54a862391a7ec8b59. 
* Fix underlying 3.9 issue in unit tests * Add some tests * Add diagnostics dicts to devices and applications * Revert API additions, just keep the `event` module --- tests/test_event.py | 214 ++++++++++++++++++++++++++++++++++++++ zigpy/event/__init__.py | 3 + zigpy/event/event_base.py | 130 +++++++++++++++++++++++ 3 files changed, 347 insertions(+) create mode 100644 tests/test_event.py create mode 100644 zigpy/event/__init__.py create mode 100644 zigpy/event/event_base.py diff --git a/tests/test_event.py b/tests/test_event.py new file mode 100644 index 000000000..173e8c2d7 --- /dev/null +++ b/tests/test_event.py @@ -0,0 +1,214 @@ +"""Event tests.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, call + +import pytest + +from zigpy.event import EventBase +from zigpy.event.event_base import EventListener + + +class EventGenerator(EventBase): + """Event generator for testing.""" + + +class Event: + """Event class for testing.""" + + event = "test" + event_type = "testing" + + +def test_event_base_unsubs(): + """Test event base class.""" + event = EventGenerator() + assert not event._listeners + assert not event._global_listeners + + callback = MagicMock() + + unsub = event.on_event("test", callback) + assert event._listeners == { + "test": [EventListener(callback=callback, with_context=False)] + } + unsub() + assert event._listeners == {"test": []} + + unsub = event.on_all_events(callback) + assert event._global_listeners == [ + EventListener(callback=callback, with_context=False) + ] + unsub() + assert not event._global_listeners + + unsub = event.once("test", callback) + assert "test" in event._listeners + assert len(event._listeners["test"]) == 1 + unsub() + assert event._listeners == {"test": []} + + +def test_event_base_emit(): + """Test event base class.""" + event = EventGenerator() + assert not event._listeners + assert not event._global_listeners + + callback = MagicMock() + + event.once("test", callback) + event.emit("test") + assert callback.called + + callback.reset_mock() + event.emit("test") + assert not callback.called + + unsub = event.on_event("test", callback) + event.emit("test") + assert callback.called + unsub() + + callback.reset_mock() + unsub = event.on_all_events(callback) + event.emit("test") + assert callback.called + unsub() + + assert "test" in event._listeners + assert event._listeners == {"test": []} + assert not event._global_listeners + + +def test_event_base_emit_data(): + """Test event base class.""" + event = EventGenerator() + assert not event._listeners + assert not event._global_listeners + + callback = MagicMock() + + event.once("test", callback) + event.emit("test", "data") + assert callback.called + assert callback.call_args[0] == ("data",) + + callback.reset_mock() + event.emit("test", "data") + assert not callback.called + + unsub = event.on_event("test", callback) + event.emit("test", "data") + assert callback.called + assert callback.call_args[0] == ("data",) + unsub() + + callback.reset_mock() + unsub = event.on_all_events(callback) + event.emit("test", "data") + assert callback.called + assert callback.call_args[0] == ("data",) + unsub() + + assert "test" in event._listeners + assert event._listeners == {"test": []} + assert not event._global_listeners + + +async def test_event_base_emit_coro(): + """Test event base class.""" + event = EventGenerator() + assert not event._listeners + assert not event._global_listeners + + callback = AsyncMock() + + event.once("test", callback) + 
event.emit("test", "data") + + await asyncio.gather(*event._event_tasks) + + assert callback.await_count == 1 + assert callback.mock_calls == [call("data")] + assert not event._event_tasks + + callback.reset_mock() + + unsub = event.on_event("test", callback) + event.emit("test", "data") + + await asyncio.gather(*event._event_tasks) + + assert callback.await_count == 1 + assert callback.mock_calls == [call("data")] + unsub() + assert not event._event_tasks + + callback.reset_mock() + + unsub = event.on_all_events(callback) + event.emit("test", "data") + + await asyncio.gather(*event._event_tasks) + + assert callback.await_count == 1 + assert callback.mock_calls == [call("data")] + unsub() + assert not event._event_tasks + + test_event = Event() + event.on_event(test_event.event, event._handle_event_protocol) + event.handle_test = AsyncMock() + + event.emit(test_event.event, test_event) + + await asyncio.gather(*event._event_tasks) + + assert event.handle_test.await_count == 1 + assert event.handle_test.mock_calls == [call(test_event)] + assert not event._event_tasks + + +async def test_event_emit_with_context(): + """Test event emitting with context.""" + + event = EventGenerator() + async_callback = AsyncMock() + sync_callback = MagicMock() + + event.once("test", sync_callback, with_context=True) + event.once("test", async_callback, with_context=True) + event.emit("test", "data") + + await asyncio.gather(*event._event_tasks) + + sync_callback.assert_called_once_with("test", "data") + async_callback.assert_awaited_once_with("test", "data") + + +def test_handle_event_protocol(): + """Test event base class.""" + + event_handler = EventGenerator() + event_handler.handle_test = MagicMock() + event_handler.on_event("test", event_handler._handle_event_protocol) + + event = Event() + event_handler.emit(event.event, event) + + assert event_handler.handle_test.called + assert event_handler.handle_test.call_args[0] == (event,) + + +def test_handle_event_protocol_no_event(caplog: pytest.LogCaptureFixture): + """Test event base class.""" + + event_handler = EventGenerator() + event_handler.on_event("not_test", event_handler._handle_event_protocol) + event = Event() + event_handler.emit("not_test", event) + + assert "Received unknown event:" in caplog.text diff --git a/zigpy/event/__init__.py b/zigpy/event/__init__.py new file mode 100644 index 000000000..c0f38f300 --- /dev/null +++ b/zigpy/event/__init__.py @@ -0,0 +1,3 @@ +from .event_base import EventBase + +__all__ = ["EventBase"] diff --git a/zigpy/event/event_base.py b/zigpy/event/event_base.py new file mode 100644 index 000000000..83d93ee78 --- /dev/null +++ b/zigpy/event/event_base.py @@ -0,0 +1,130 @@ +"""Provide Event base classes for zigpy.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +import dataclasses +import logging +import sys +from typing import Any + +if sys.version_info >= (3, 10): + from inspect import iscoroutinefunction +else: + # https://github.com/python/cpython/issues/84753 + from asyncio import iscoroutinefunction + +_LOGGER = logging.getLogger(__package__) + + +@dataclasses.dataclass( + frozen=True, **({"slots": True} if sys.version_info >= (3, 10) else {}) +) +class EventListener: + """Listener for an event.""" + + callback: Callable + with_context: bool + + +class EventBase: + """Base class for event handling and emitting objects.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize event base.""" + super().__init__(*args, **kwargs) + 
self._listeners: dict[str, list[EventListener]] = {} + self._event_tasks: list[asyncio.Task] = [] + self._global_listeners: list[EventListener] = [] + + def on_event( # pylint: disable=invalid-name + self, event_name: str, callback: Callable, with_context: bool = False + ) -> Callable: + """Register an event callback.""" + listener = EventListener(callback=callback, with_context=with_context) + + listeners: list = self._listeners.setdefault(event_name, []) + listeners.append(listener) + + def unsubscribe() -> None: + """Unsubscribe listeners.""" + if listener in listeners: + listeners.remove(listener) + + return unsubscribe + + def on_all_events( # pylint: disable=invalid-name + self, callback: Callable, with_context: bool = False + ) -> Callable: + """Register a callback for all events.""" + listener = EventListener(callback=callback, with_context=with_context) + self._global_listeners.append(listener) + + def unsubscribe() -> None: + """Unsubscribe listeners.""" + if listener in self._global_listeners: + self._global_listeners.remove(listener) + + return unsubscribe + + def once( + self, event_name: str, callback: Callable, with_context: bool = False + ) -> Callable: + """Listen for an event exactly once.""" + if iscoroutinefunction(callback): + + async def async_event_listener(*args, **kwargs) -> None: + unsub() + task = asyncio.create_task(callback(*args, **kwargs)) + self._event_tasks.append(task) + task.add_done_callback(self._event_tasks.remove) + + unsub = self.on_event( + event_name, async_event_listener, with_context=with_context + ) + return unsub # noqa: RET504 + + def event_listener(*args, **kwargs) -> None: + unsub() + callback(*args, **kwargs) + + unsub = self.on_event(event_name, event_listener, with_context=with_context) + return unsub # noqa: RET504 + + def emit(self, event_name: str, data=None) -> None: + """Run all callbacks for an event.""" + listeners = [*self._listeners.get(event_name, []), *self._global_listeners] + _LOGGER.debug( + "Emitting event %s with data %r (%d listeners)", + event_name, + data, + len(listeners), + ) + + for listener in listeners: + if listener.with_context: + call = listener.callback(event_name, data) + else: + call = listener.callback(data) + + if iscoroutinefunction(listener.callback): + task = asyncio.create_task(call) + self._event_tasks.append(task) + task.add_done_callback(self._event_tasks.remove) + + def _handle_event_protocol(self, event) -> None: + """Process an event based on event protocol.""" + _LOGGER.debug( + "(%s) handling event protocol for event: %s", self.__class__.__name__, event + ) + handler = getattr(self, f"handle_{event.event.replace(' ', '_')}", None) + if handler is None: + _LOGGER.warning("Received unknown event: %s", event) + return + if iscoroutinefunction(handler): + task = asyncio.create_task(handler(event)) + self._event_tasks.append(task) + task.add_done_callback(self._event_tasks.remove) + else: + handler(event) From fffc410522d0b965176a9a0458e565a0a229329c Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 10 Sep 2025 14:47:16 -0400 Subject: [PATCH 6/6] Reserve capacity for high priority requests (#1635) * Remove unused data structures * Implement a `RequestLimiter` type * Use it instead of the semaphore * Add `cancel_waiting` * Fixes * Remove old semaphore * Revert "Remove unused data structures" This reverts commit 71e095aa9fad57df0f49cd1c6162d187c5bf11b3. 
* WIP * Get things running * Fix tests * Clean up implementation * WIP: add CPython semaphore unit tests * WIP * WIP: initial commit of CPython source * WIP: minimal API compatibility * WIP: make release synchronous * WIP: get most passing * Fix repr test * Fix remaining tests * Oops * Bring up coverage * Drop unnecessary test * Fix tests for 3.10 * Remove unnecessary code * Add backwards compatible `max_value` property for existing radio libraries * Add a unit test * Gracefully handle concurrency limits incompatible with concurrency fractions * Remove unnecessary error --- tests/test_application.py | 32 +-- tests/test_datastructures.py | 249 +++++++++++++++++++++ tests/test_datastructures_cpython.py | 310 +++++++++++++++++++++++++++ tests/test_device.py | 87 ++++---- zigpy/application.py | 15 +- zigpy/config/__init__.py | 23 ++ zigpy/config/defaults.py | 5 + zigpy/datastructures.py | 233 ++++++++++++++++++++ zigpy/device.py | 23 +- zigpy/util.py | 1 - 10 files changed, 909 insertions(+), 69 deletions(-) create mode 100644 tests/test_datastructures_cpython.py diff --git a/tests/test_application.py b/tests/test_application.py index e53f95ade..3c4ec67bb 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -9,6 +9,7 @@ import zigpy.application import zigpy.config as conf +from zigpy.datastructures import RequestLimiter from zigpy.exceptions import ( DeliveryError, NetworkNotFormed, @@ -714,19 +715,18 @@ async def test_request_concurrency(): peak_concurrency = 0 class SlowApp(App): - async def send_packet(self, packet): + async def _send_packet(self, packet): nonlocal current_concurrency, peak_concurrency - async with self._limit_concurrency(): - current_concurrency += 1 - peak_concurrency = max(peak_concurrency, current_concurrency) + current_concurrency += 1 + peak_concurrency = max(peak_concurrency, current_concurrency) - await asyncio.sleep(0.1) - current_concurrency -= 1 + await asyncio.sleep(0.1) + current_concurrency -= 1 - if packet % 10 == 7: - # Fail randomly - raise DeliveryError("Failure") + if packet % 10 == 7: + # Fail randomly + raise DeliveryError("Failure") app = make_app({conf.CONF_MAX_CONCURRENT_REQUESTS: 16}, app_base=SlowApp) @@ -734,7 +734,11 @@ async def send_packet(self, packet): assert peak_concurrency == 0 await asyncio.gather( - *[app.send_packet(i) for i in range(100)], return_exceptions=True + *[ + app.send_packet(t.ZigbeePacket(priority=t.PacketPriority.HIGH)) + for i in range(100) + ], + return_exceptions=True, ) assert current_concurrency == 0 @@ -1635,7 +1639,9 @@ async def test_packet_capture(app) -> None: async def test_request_priority(app) -> None: - app._concurrent_requests_semaphore.max_value = 1 + app._concurrent_requests_semaphore = RequestLimiter( + max_concurrency=1, capacities={t.PacketPriority.LOW: 1} + ) with patch.object(app, "_send_packet", wraps=app._send_packet) as mock_send_packet: packet_low = Mock(name="LOW", priority=t.PacketPriority.LOW) @@ -1663,7 +1669,9 @@ async def test_request_priority(app) -> None: async def test_request_priority_context_concurrency(app, packet): """Test that request_priority contexts work correctly with concurrent tasks.""" # Limit concurrency to see priority ordering effects - app._concurrent_requests_semaphore.max_value = 1 + app._concurrent_requests_semaphore = RequestLimiter( + max_concurrency=1, capacities={t.PacketPriority.LOW: 1} + ) with patch.object(app, "_send_packet", wraps=app._send_packet) as mock_send: diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 
54a6a0018..37fff043b 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,4 +1,5 @@ import asyncio +import logging from unittest.mock import Mock, patch import pytest @@ -489,3 +490,251 @@ async def test_debouncer_cleaning_bug(): # The queue should be empty assert len(debouncer._queue) == 0 + + +async def test_request_limiter_simple_tier(): + """Test the limiter behaving like a simple semaphore with one tier.""" + limiter = datastructures.RequestLimiter(2, {0: 1}) + results = [] + order = [] + + async def worker(uid): + nonlocal order + order.append(f"trying-{uid}") + async with limiter(priority=0): + order.append(f"acquired-{uid}") + results.append(f"start-{uid}") + await asyncio.sleep(0.05) + results.append(f"end-{uid}") + + tasks = [asyncio.create_task(worker(i)) for i in range(3)] + + # The first two should start immediately + await asyncio.sleep(0) + assert limiter.active_requests == 2 + assert limiter.waiting_requests == 1 + assert "active=2, waiting=1" in repr(limiter) + + # Wait for all to complete + await asyncio.gather(*tasks) + + assert limiter.active_requests == 0 + assert limiter.waiting_requests == 0 + # The exact end order depends on scheduling, but the start order is what matters. + assert results[:2] == ["start-0", "start-1"] + assert results[2:4] == ["end-0", "end-1"] + assert results[4] == "start-2" + assert results[5] == "end-2" + + +async def test_request_limiter_cascading_priority(): + """Test the core cascading priority logic.""" + # 2 slots for low priority, 1 extra for high priority (total 3) + limiter = datastructures.RequestLimiter(3, {0: 2 / 3, 10: 1 / 3}) + events = [] + + async def worker(uid, priority, delay): + nonlocal events + async with limiter(priority=priority): + events.append(f"start-{uid}(p={priority})") + await asyncio.sleep(delay) + events.append(f"end-{uid}(p={priority})") + + # These two should take the low-priority slots + t1 = asyncio.create_task(worker(1, 0, 0.1)) + t2 = asyncio.create_task(worker(2, 0, 0.1)) + + # This one should have to wait + t3 = asyncio.create_task(worker(3, 0, 0.1)) + + await asyncio.sleep(0) + assert limiter.active_requests == 2 + assert limiter.waiting_requests == 1 + assert events == ["start-1(p=0)", "start-2(p=0)"] + + # This high-priority one should be able to acquire the reserved slot + t4 = asyncio.create_task(worker(4, 10, 0.1)) + await asyncio.sleep(0) + assert limiter.active_requests == 3 + assert limiter.waiting_requests == 1 + assert "start-4(p=10)" in events + + # No more slots available for anyone + t5 = asyncio.create_task(worker(5, 10, 0.1)) + await asyncio.sleep(0) + assert limiter.active_requests == 3 + assert limiter.waiting_requests == 2 + + await asyncio.gather(t1, t2, t3, t4, t5) + + assert limiter.active_requests == 0 + assert limiter.waiting_requests == 0 + + +async def test_request_limiter_priority_queueing(): + """Test that waiters are woken up according to their priority.""" + limiter = datastructures.RequestLimiter(1, {0: 1}) + events = [] + + async def worker(uid, priority): + nonlocal events + async with limiter(priority=priority): + events.append(uid) + await asyncio.sleep(0.1) # Hold the lock long enough for others to queue + + # First task takes the only slot + t1 = asyncio.create_task(worker(1, 0)) + await asyncio.sleep(0) # Ensure t1 acquires the lock + assert limiter.active_requests == 1 + + # Queue up waiters with different priorities + t2 = asyncio.create_task(worker(2, 5)) # medium + t3 = asyncio.create_task(worker(3, 0)) # low + t4 = 
asyncio.create_task(worker(4, 10)) # high + + await asyncio.sleep(0) # Ensure all waiters have tried to acquire + assert limiter.waiting_requests == 3 + + # Wait for all tasks to complete + await asyncio.gather(t1, t2, t3, t4) + + # The waiters should have run in priority order, not submission order + assert events == [1, 4, 2, 3] + + +async def test_request_limiter_cancellation(): + """Test cancelling tasks that are waiting for or holding a slot.""" + limiter = datastructures.RequestLimiter(1, {0: 1}) + + async def holder(): + async with limiter(priority=0): + await asyncio.sleep(60) + + async def waiter(): + async with limiter(priority=0): + pytest.fail("Waiter should not acquire the slot") + + # t1 acquires the only slot + t1 = asyncio.create_task(holder()) + await asyncio.sleep(0) + assert limiter.active_requests == 1 + + # t2 waits in line + t2 = asyncio.create_task(waiter()) + await asyncio.sleep(0) + assert limiter.waiting_requests == 1 + + # Cancel the waiting task + t2.cancel() + with pytest.raises(asyncio.CancelledError): + await t2 + + assert limiter.waiting_requests == 0 + assert limiter.active_requests == 1 + + # Now, cancel the holder task + t1.cancel() + with pytest.raises(asyncio.CancelledError): + await t1 + + assert limiter.active_requests == 0 + assert limiter.waiting_requests == 0 + + # The limiter should be usable again + async with limiter(priority=0): + assert limiter.active_requests == 1 + + +async def test_request_limiter_invalid_priority(): + """Test that using a priority with no allocated tier raises an error.""" + limiter = datastructures.RequestLimiter(1, {10: 1}) + + with pytest.raises( + ValueError, + match="Priority 5 is lower than the lowest known priority 10 and has no allocated capacity", + ): + async with limiter(priority=5): + pytest.fail("This code should not be reached") + + +async def test_request_limiter_highest_priority(): + """Test that the highest priority task is always allowed to run.""" + limiter = datastructures.RequestLimiter(1, {100: 1}) + + async with limiter(priority=100 + 1): + assert limiter.active_requests == 1 + + assert limiter.locked(priority=100) + + +async def test_request_limiter_max_concurrency_property(): + """Test the max_concurrency property getter.""" + limiter = datastructures.RequestLimiter(5, {1: 1.0}) + assert limiter.max_concurrency == 5 + + +async def test_request_limiter_priority_higher_than_known(): + """Test priority higher than the highest known tier.""" + limiter = datastructures.RequestLimiter(2, {5: 0.5, 10: 0.5}) + + # Priority 15 is higher than highest known (10), should use highest tier + async with limiter(priority=15): + assert limiter.active_requests == 1 + + +async def test_request_limiter_cancel_waiting(): + """Test cancelling all waiting tasks.""" + limiter = datastructures.RequestLimiter(1, {1: 1.0}) + + # Acquire the limiter first + async with limiter(priority=1): + # Create waiting tasks + async def waiter(): + await limiter._acquire(priority=1) + + task1 = asyncio.create_task(waiter()) + task2 = asyncio.create_task(waiter()) + await asyncio.sleep(0) # Let tasks start waiting + + assert limiter.waiting_requests == 2 + + # Cancel all waiting tasks + exc = RuntimeError("Test cancellation") + limiter.cancel_waiting(exc) + + # Tasks should be cancelled with the provided exception + with pytest.raises(RuntimeError, match="Test cancellation"): + await task1 + with pytest.raises(RuntimeError, match="Test cancellation"): + await task2 + + +async def test_request_limiter_backwards_compatibility() -> None: + 
"""Test `max_value` property proxying `max_concurrency`.""" + limiter = datastructures.RequestLimiter(5, {1: 1.0}) + + with pytest.deprecated_call(): + assert limiter.max_value == 5 + + with pytest.deprecated_call(): + limiter.max_value = 10 + + assert limiter.max_concurrency == 10 + + with pytest.deprecated_call(): + assert limiter.max_value == 10 + + +async def test_request_limiter_adjust_max_concurrency( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test adjusting max concurrency at runtime snaps to a reasonable value.""" + limiter = datastructures.RequestLimiter(8, {1: 1 / 3, 2: 1 / 2, 3: 1 / 6}) + + assert limiter.max_concurrency == 12 + + with caplog.at_level(logging.WARNING): + limiter.max_concurrency = 13 + + # It'll be bumped to 18, which is divisible by 3, 2, and 6 + assert limiter.max_concurrency == 18 diff --git a/tests/test_datastructures_cpython.py b/tests/test_datastructures_cpython.py new file mode 100644 index 000000000..9662e3e9f --- /dev/null +++ b/tests/test_datastructures_cpython.py @@ -0,0 +1,310 @@ +"""CPython semaphore unit tests minimally modified to work with RequestLimiter.""" + +# ruff: noqa: PT009, PT027 + +import asyncio +import re +import unittest + +from zigpy.datastructures import RequestLimiter + +STR_RGX_REPR = ( + r"^<(?P.*?) object at (?P
.*?)" + r"\[(?P" + r"(set|unset|locked|unlocked|filling|draining|resetting|broken)" + r"(, value:\d)?" + r"(, waiters:\d+)?" + r"(, waiters:\d+\/\d+)?" # barrier + r")\]>\Z" +) +RGX_REPR = re.compile(STR_RGX_REPR) + + +class SemaphoreTests(unittest.IsolatedAsyncioTestCase): + def test_initial_value_zero(self): + sem = RequestLimiter(max_concurrency=0, capacities={1: 1.0}) + self.assertTrue(sem.locked(priority=1)) + + async def test_repr(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + # RequestLimiter format: + self.assertIn("max_concurrency=1", repr(sem)) + self.assertIn("active=0", repr(sem)) + self.assertIn("waiting=0", repr(sem)) + + await sem._acquire(priority=1) + self.assertIn("active=1", repr(sem)) + self.assertIn("waiting=0", repr(sem)) + + # Start tasks that will wait since semaphore is already acquired + task1 = asyncio.create_task(sem._acquire(priority=1)) + await asyncio.sleep(0) # Let task1 start and get queued + self.assertIn("waiting=1", repr(sem)) + + task2 = asyncio.create_task(sem._acquire(priority=1)) + await asyncio.sleep(0) # Let task2 start and get queued + self.assertIn("waiting=2", repr(sem)) + + # Clean up + task1.cancel() + task2.cancel() + + await asyncio.gather(task1, task2, return_exceptions=True) + + async def test_semaphore(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + # self.assertEqual(1, sem.active_requests) + + with self.assertRaisesRegex( + TypeError, + "object RequestLimiter can't be used in 'await' expression", + ): + await sem + + self.assertFalse(sem.locked(priority=1)) + # self.assertEqual(1, sem.active_requests) + + def test_semaphore_value(self): + self.assertRaises(ValueError, RequestLimiter, -1, {}) + + async def test_acquire(self): + sem = RequestLimiter(max_concurrency=3, capacities={1: 1.0}) + result = [] + + self.assertTrue(await sem._acquire(priority=1)) + self.assertTrue(await sem._acquire(priority=1)) + self.assertFalse(sem.locked(priority=1)) + + async def c1(result): + await sem._acquire(priority=1) + result.append(1) + return True + + async def c2(result): + await sem._acquire(priority=1) + result.append(2) + return True + + async def c3(result): + await sem._acquire(priority=1) + result.append(3) + return True + + async def c4(result): + await sem._acquire(priority=1) + result.append(4) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(sem.locked(priority=1)) + self.assertEqual(2, sem.waiting_requests) + # self.assertEqual(0, sem.active_requests) + + t4 = asyncio.create_task(c4(result)) + + sem._release(priority=1) + sem._release(priority=1) + # self.assertEqual(3, sem.active_requests) + + await asyncio.sleep(0) + # self.assertEqual(0, sem.active_requests) + self.assertEqual(3, len(result)) + self.assertTrue(sem.locked(priority=1)) + self.assertEqual(1, sem.waiting_requests) + # self.assertEqual(0, sem.active_requests) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + race_tasks = [t2, t3, t4] + done_tasks = [t for t in race_tasks if t.done() and t.result()] + self.assertEqual(2, len(done_tasks)) + + # cleanup locked semaphore + sem._release(priority=1) + await asyncio.gather(*race_tasks) + + async def test_acquire_cancel(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + await sem._acquire(priority=1) + + acquire = asyncio.create_task(sem._acquire(priority=1)) + 
asyncio.get_running_loop().call_soon(acquire.cancel) + with self.assertRaises(asyncio.CancelledError): + await acquire + self.assertTrue( + (not sem._waiters) or all(waiter.done() for waiter in sem._waiters) + ) + + async def test_acquire_cancel_before_awoken(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + + t1 = asyncio.create_task(sem._acquire(priority=1)) + t2 = asyncio.create_task(sem._acquire(priority=1)) + t3 = asyncio.create_task(sem._acquire(priority=1)) + t4 = asyncio.create_task(sem._acquire(priority=1)) + + await asyncio.sleep(0) + + t1.cancel() + t2.cancel() + sem._release(priority=1) + + await asyncio.sleep(0) + await asyncio.sleep(0) + num_done = sum(t.done() for t in [t3, t4]) + self.assertEqual(num_done, 1) + self.assertTrue(t3.done()) + self.assertFalse(t4.done()) + + t3.cancel() + t4.cancel() + await asyncio.sleep(0) + + async def test_acquire_hang(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + + t1 = asyncio.create_task(sem._acquire(priority=1)) + t2 = asyncio.create_task(sem._acquire(priority=1)) + await asyncio.sleep(0) + + t1.cancel() + sem._release(priority=1) + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertTrue(sem.locked(priority=1)) + self.assertTrue(t2.done()) + + async def test_acquire_no_hang(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + + async def c1(): + async with sem(priority=1): + await asyncio.sleep(0) + t2.cancel() + + async def c2(): + async with sem(priority=1): + self.assertFalse(True) + + t1 = asyncio.create_task(c1()) + t2 = asyncio.create_task(c2()) + + r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True) + self.assertTrue(r1 is None) + self.assertTrue(isinstance(r2, asyncio.CancelledError)) + + await asyncio.wait_for(sem._acquire(priority=1), timeout=1.0) + + def test_release_not_acquired(self): + sem = asyncio.BoundedSemaphore() + + self.assertRaises(ValueError, sem.release) + + async def test_release_no_waiters(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + await sem._acquire(priority=1) + self.assertTrue(sem.locked(priority=1)) + + sem._release(priority=1) + self.assertFalse(sem.locked(priority=1)) + + async def test_acquire_fifo_order(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + result = [] + + async def coro(tag): + await sem._acquire(priority=1) + result.append(f"{tag}_1") + await asyncio.sleep(0.01) + sem._release(priority=1) + + await sem._acquire(priority=1) + result.append(f"{tag}_2") + await asyncio.sleep(0.01) + sem._release(priority=1) + + tasks = [] + tasks.append(asyncio.create_task(coro("c1"))) + tasks.append(asyncio.create_task(coro("c2"))) + tasks.append(asyncio.create_task(coro("c3"))) + await asyncio.gather(*tasks, return_exceptions=True) + + self.assertEqual(["c1_1", "c2_1", "c3_1", "c1_2", "c2_2", "c3_2"], result) + + async def test_acquire_fifo_order_2(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + result = [] + + async def c1(result): + await sem._acquire(priority=1) + result.append(1) + return True + + async def c2(result): + await sem._acquire(priority=1) + result.append(2) + sem._release(priority=1) + await sem._acquire(priority=1) + result.append(4) + return True + + async def c3(result): + await sem._acquire(priority=1) + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + + sem._release(priority=1) + 
sem._release(priority=1) + + tasks = [t1, t2, t3] + await asyncio.gather(*tasks) + # self.assertEqual([1, 2, 3, 4], result) + self.assertEqual([1, 2, 4, 3], result) # We differ here + + async def test_acquire_fifo_order_3(self): + sem = RequestLimiter(max_concurrency=1, capacities={1: 1.0}) + result = [] + + async def c1(result): + await sem._acquire(priority=1) + result.append(1) + return True + + async def c2(result): + await sem._acquire(priority=1) + result.append(2) + return True + + async def c3(result): + await sem._acquire(priority=1) + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + + t1.cancel() + + await asyncio.sleep(0) + + sem._release(priority=1) + sem._release(priority=1) + + tasks = [t1, t2, t3] + await asyncio.gather(*tasks, return_exceptions=True) + # self.assertEqual([2, 3], result) + self.assertEqual([1, 2, 3], result) # We differ here diff --git a/tests/test_device.py b/tests/test_device.py index 560a98de2..434f20616 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import AsyncExitStack from datetime import datetime, timezone import logging import math @@ -8,6 +9,7 @@ from zigpy import device, endpoint import zigpy.application +from zigpy.datastructures import RequestLimiter import zigpy.exceptions from zigpy.ota import OtaImagesResult import zigpy.ota.image @@ -1278,7 +1280,9 @@ async def test_debouncing(dev): async def test_device_concurrency(dev: device.Device) -> None: """Test that the device can handle multiple requests concurrently.""" - dev._concurrent_requests_semaphore.max_value = 1 + dev._concurrent_requests_semaphore = RequestLimiter( + max_concurrency=1, capacities={t.PacketPriority.LOW: 1} + ) ep = dev.add_endpoint(1) ep.add_input_cluster(Basic.cluster_id) @@ -1363,7 +1367,7 @@ async def delayed_receive(*args, **kwargs) -> None: raise asyncio.TimeoutError() dev._application.request = AsyncMock(side_effect=delayed_receive) - dev._concurrent_requests_semaphore.max_value = 100000 + dev._concurrent_requests_semaphore.max_concurrency = 100000 # We send 256 + 1 requests errors = await asyncio.gather( @@ -1471,7 +1475,9 @@ async def test_poll_control_checkin_callback( expected_fast_poll: bool, ) -> None: """Test PollControl check-in callback with different device states.""" - dev._concurrent_requests_semaphore.max_value = 1 + dev._concurrent_requests_semaphore = RequestLimiter( + max_concurrency=1, capacities={t.PacketPriority.LOW: 1} + ) ep = dev.add_endpoint(1) poll_control = ep.add_input_cluster(PollControl.cluster_id) @@ -1486,46 +1492,45 @@ async def test_poll_control_checkin_callback( else: dev._initialize_task = None - if semaphore_locked: - await dev._concurrent_requests_semaphore.acquire() - - zcl_hdr = foundation.ZCLHeader( - frame_control=foundation.FrameControl( - frame_type=foundation.FrameType.CLUSTER_COMMAND, - is_manufacturer_specific=False, - direction=foundation.Direction.Server_to_Client, - disable_default_response=1, - reserved=0, - ), - tsn=0x12, - command_id=PollControl.ClientCommandDefs.checkin.id, - ) - command = PollControl.ClientCommandDefs.checkin.schema() - - # Test the callback - await dev.poll_control_checkin_callback(zcl_hdr, command) - - # Verify the correct response was sent - if expected_fast_poll: - assert poll_control.checkin_response.mock_calls == [ - call( - start_fast_polling=expected_fast_poll, - 
fast_poll_timeout=int(device.DEFAULT_FAST_POLL_TIMEOUT * 4), - tsn=0x12, - ) - ] - else: - assert poll_control.checkin_response.mock_calls == [ - call( - start_fast_polling=expected_fast_poll, - fast_poll_timeout=0, - tsn=0x12, + async with AsyncExitStack() as stack: + if semaphore_locked: + await stack.enter_async_context( + dev._concurrent_requests_semaphore(priority=t.PacketPriority.LOW) ) - ] - # Clean up semaphore if we acquired it - if semaphore_locked: - dev._concurrent_requests_semaphore.release() + zcl_hdr = foundation.ZCLHeader( + frame_control=foundation.FrameControl( + frame_type=foundation.FrameType.CLUSTER_COMMAND, + is_manufacturer_specific=False, + direction=foundation.Direction.Server_to_Client, + disable_default_response=1, + reserved=0, + ), + tsn=0x12, + command_id=PollControl.ClientCommandDefs.checkin.id, + ) + command = PollControl.ClientCommandDefs.checkin.schema() + + # Test the callback + await dev.poll_control_checkin_callback(zcl_hdr, command) + + # Verify the correct response was sent + if expected_fast_poll: + assert poll_control.checkin_response.mock_calls == [ + call( + start_fast_polling=expected_fast_poll, + fast_poll_timeout=int(device.DEFAULT_FAST_POLL_TIMEOUT * 4), + tsn=0x12, + ) + ] + else: + assert poll_control.checkin_response.mock_calls == [ + call( + start_fast_polling=expected_fast_poll, + fast_poll_timeout=0, + tsn=0x12, + ) + ] async def test_begin_fast_polling_with_cluster(dev: device.Device) -> None: diff --git a/zigpy/application.py b/zigpy/application.py index 1835816d1..5b5c72072 100644 --- a/zigpy/application.py +++ b/zigpy/application.py @@ -29,7 +29,7 @@ import zigpy.backups import zigpy.config as conf from zigpy.const import INTERFERENCE_MESSAGE -from zigpy.datastructures import PriorityDynamicBoundedSemaphore +from zigpy.datastructures import RequestLimiter import zigpy.device import zigpy.endpoint import zigpy.exceptions @@ -79,8 +79,9 @@ def __init__(self, config: dict) -> None: self._watchdog_task: asyncio.Task | None = None - self._concurrent_requests_semaphore = PriorityDynamicBoundedSemaphore( - self._config[conf.CONF_MAX_CONCURRENT_REQUESTS] + self._concurrent_requests_semaphore = RequestLimiter( + max_concurrency=self._config[conf.CONF_MAX_CONCURRENT_REQUESTS], + capacities=self._config[conf.CONF_EXPERIMENTAL][conf.CONF_CONCURRENCY], ) self.ota = zigpy.ota.OTA(self._config[conf.CONF_OTA], self) @@ -784,19 +785,19 @@ async def _limit_concurrency( LOGGER.debug( "Critical priority request received (%s), skipping queue with %d requests", priority, - self._concurrent_requests_semaphore.num_waiting, + self._concurrent_requests_semaphore.waiting_requests, ) manager = nullcontext() was_locked = False else: manager = self._concurrent_requests_semaphore(priority=priority) - was_locked = self._concurrent_requests_semaphore.locked() + was_locked = self._concurrent_requests_semaphore.locked(priority=priority) if was_locked: LOGGER.debug( "Max concurrency (%s) reached, delaying request (%s enqueued)", - self._concurrent_requests_semaphore.max_value, - self._concurrent_requests_semaphore.num_waiting, + self._concurrent_requests_semaphore.active_requests, + self._concurrent_requests_semaphore.waiting_requests, ) async with manager: diff --git a/zigpy/config/__init__.py b/zigpy/config/__init__.py index e3bb0dfc3..f28c96123 100644 --- a/zigpy/config/__init__.py +++ b/zigpy/config/__init__.py @@ -5,6 +5,7 @@ import voluptuous as vol from zigpy.config.defaults import ( + CONF_CONCURRENCY_DEFAULT, CONF_DEVICE_BAUDRATE_DEFAULT, 
CONF_DEVICE_FLOW_CONTROL_DEFAULT, CONF_MAX_CONCURRENT_REQUESTS_DEFAULT, @@ -91,6 +92,8 @@ CONF_TOPO_SCAN_ENABLED = "topology_scan_enabled" CONF_TOPO_SKIP_COORDINATOR = "topology_scan_skip_coordinator" CONF_WATCHDOG_ENABLED = "watchdog_enabled" +CONF_EXPERIMENTAL = "experimental" +CONF_CONCURRENCY = "concurrency" CONF_OTA_ALLOW_ADVANCED_DIR_STRING = ( "I understand I can *destroy* my devices by enabling OTA updates from files." @@ -343,6 +346,25 @@ {**SCHEMA_OTA_BASE, **SCHEMA_OTA_DEPRECATED}, extra=vol.ALLOW_EXTRA ) +SCHEMA_EXPERIMENTAL = vol.Schema( + { + vol.Optional(CONF_CONCURRENCY, default=CONF_CONCURRENCY_DEFAULT): vol.All( + vol.Schema( + { + vol.Required(name.lower()): vol.All( + float, + vol.Range(min=0, max=1, min_included=False, max_included=False), + ) + for name, value in t.PacketPriority.__members__.items() + if value != t.PacketPriority.CRITICAL + } + ), + # Coerce the key types + lambda d: {t.PacketPriority[k.upper()]: v for k, v in d.items()}, + ), + } +) + ZIGPY_SCHEMA = vol.Schema( { vol.Optional(CONF_DATABASE, default=None): vol.Any(None, str), @@ -379,6 +401,7 @@ vol.Optional( CONF_WATCHDOG_ENABLED, default=CONF_WATCHDOG_ENABLED_DEFAULT ): cv_boolean, + vol.Optional(CONF_EXPERIMENTAL, default={}): SCHEMA_EXPERIMENTAL, }, extra=vol.ALLOW_EXTRA, ) diff --git a/zigpy/config/defaults.py b/zigpy/config/defaults.py index 44bc24553..2ff28b295 100644 --- a/zigpy/config/defaults.py +++ b/zigpy/config/defaults.py @@ -50,3 +50,8 @@ CONF_TOPO_SCAN_ENABLED_DEFAULT = True CONF_TOPO_SKIP_COORDINATOR_DEFAULT = False CONF_WATCHDOG_ENABLED_DEFAULT = True +CONF_CONCURRENCY_DEFAULT = { + t.PacketPriority.HIGH.name.lower(): 0.25, + t.PacketPriority.NORMAL.name.lower(): 0.50, + t.PacketPriority.LOW.name.lower(): 0.25, +} diff --git a/zigpy/datastructures.py b/zigpy/datastructures.py index fecac89d4..a3d94d325 100644 --- a/zigpy/datastructures.py +++ b/zigpy/datastructures.py @@ -4,10 +4,18 @@ import asyncio import bisect +import collections import contextlib +from fractions import Fraction import functools +import heapq +import logging +import math import types import typing +import warnings + +_LOGGER = logging.getLogger(__name__) class WrappedContextManager: @@ -309,3 +317,228 @@ def filter(self, obj: typing.Any, expire_in: float) -> bool: def __repr__(self) -> str: """String representation of the debouncer.""" return f"<{self.__class__.__name__} [tracked:{len(self._queue)}]>" + + +class _LimiterContext: + """Helper class to manage the async context for the RequestLimiter.""" + + def __init__(self, limiter: RequestLimiter, priority: int) -> None: + self._limiter = limiter + self._priority = priority + + async def __aenter__(self) -> None: + """Acquire a slot from the limiter.""" + await self._limiter._acquire(self._priority) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + """Release the slot back to the limiter.""" + self._limiter._release(self._priority) + + +class RequestLimiter: + """Limits concurrent requests with cascading capacity for multiple priority levels.""" + + def __init__(self, max_concurrency: int, capacities: dict[int, float]) -> None: + """Initializes the RequestLimiter.""" + if max_concurrency < 0: + raise ValueError(f"max_concurrency must be >= 0: {max_concurrency}") + + self._lock = asyncio.Lock() + self._capacity_fractions = capacities + self._sorted_priorities = sorted(capacities.keys()) + + self._cumulative_capacity: dict[int, int] = {} + 
self._max_concurrency = max_concurrency
+        self._recalculate_capacity()
+
+        self._active_requests_by_tier: typing.Counter[int] = collections.Counter()
+        self._waiters: list[tuple[int, int, asyncio.Future]] = []
+        self._comparison_counter = 0
+
+    @property
+    def active_requests(self) -> int:
+        """Returns the total number of currently running requests."""
+        return sum(self._active_requests_by_tier.values())
+
+    @property
+    def waiting_requests(self) -> int:
+        """Returns the number of requests waiting for a slot."""
+        return len(self._waiters)
+
+    @property
+    def max_concurrency(self) -> int:
+        """Returns the maximum concurrency of the limiter."""
+        return self._max_concurrency
+
+    @max_concurrency.setter
+    def max_concurrency(self, new_value: int) -> None:
+        """Updates the maximum concurrency of the limiter."""
+        self._max_concurrency = new_value
+        self._recalculate_capacity()
+        self._wake_waiters()
+
+    @property
+    def max_value(self) -> int:
+        """Deprecated alias for `max_concurrency`."""
+        warnings.warn(
+            "`max_value` is deprecated, use `max_concurrency` instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.max_concurrency
+
+    @max_value.setter
+    def max_value(self, new_value: int) -> None:
+        """Deprecated setter alias for `max_concurrency`."""
+        warnings.warn(
+            "`max_value` is deprecated, use `max_concurrency` instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        self.max_concurrency = new_value
+
+    def _recalculate_capacity(self) -> None:
+        # Assume that all of the fractions are simple
+        divisors = [
+            Fraction.from_float(f).limit_denominator(10).denominator
+            for f in self._capacity_fractions.values()
+        ]
+
+        lcm = math.lcm(*divisors)
+
+        if self._max_concurrency % lcm != 0:
+            next_best = lcm - self._max_concurrency % lcm
+            assert next_best > 0
+
+            _LOGGER.warning(
+                "Requested adapter concurrency %d is not compatible with priority"
+                " fractions %r. Increasing concurrency to %d.",
+                self._max_concurrency,
+                self._capacity_fractions,
+                self._max_concurrency + next_best,
+            )
+            self._max_concurrency += next_best
+
+        cumulative_capacity = 0
+
+        for priority in self._sorted_priorities:
+            portion = self._capacity_fractions[priority] * self._max_concurrency
+            cumulative_capacity += round(portion)
+
+            self._cumulative_capacity[priority] = cumulative_capacity
+
+    def __call__(self, priority: int = 0) -> _LimiterContext:
+        """Returns an async context manager to safely acquire and release a slot."""
+        return _LimiterContext(self, priority)
+
+    def _get_effective_priority_tier(self, priority: int) -> int:
+        """Finds the capacity tier that the given priority falls into."""
+        if priority < self._sorted_priorities[0]:
+            raise ValueError(
+                f"Priority {priority} is lower than the lowest known priority "
+                f"{self._sorted_priorities[0]} and has no allocated capacity"
+            )
+
+        idx = bisect.bisect_right(self._sorted_priorities, priority)
+        return self._sorted_priorities[idx - 1]
+
+    def locked(self, priority: int) -> bool:
+        """Checks whether a request with the given priority would have to wait."""
+        effective_tier = self._get_effective_priority_tier(priority)
+        limit = self._cumulative_capacity[effective_tier]
+        active_requests = 0
+
+        if limit == 0:
+            return True
+
+        for tier, count in self._active_requests_by_tier.items():
+            if tier > effective_tier:
+                continue
+
+            active_requests += count
+            if active_requests >= limit:
+                return True
+
+        return False
+
+    def _wake_waiters(self) -> None:
+        """Wakes up any waiting tasks that can now run."""
+        while self._waiters:
+            waiter_priority, _, fut = self._waiters[0]
+            priority = -waiter_priority  # We flip the sign when storing, for comparison
+
+            if not self.locked(priority):
+                heapq.heappop(self._waiters)
+                if not fut.done():
+                    effective_tier = self._get_effective_priority_tier(priority)
+                    self._active_requests_by_tier[effective_tier] += 1
+                    fut.set_result(None)
+            else:
+                break
+
+    async def _acquire(self, priority: int = 0) -> bool:
+        """Acquires a slot in the limiter, waiting if necessary."""
+        effective_tier = self._get_effective_priority_tier(priority)
+
+        # A task can run immediately if it has capacity AND it has a higher
+        # priority than any task already waiting. This allows high-priority
+        # tasks to jump the queue, while maintaining FIFO for tasks of the
+        # same priority.
+        highest_waiter_priority = (
+            -self._waiters[0][0] if self._waiters else -float("inf")
+        )
+
+        if not self.locked(priority) and priority > highest_waiter_priority:
+            self._active_requests_by_tier[effective_tier] += 1
+            return True
+
+        # To ensure that our objects don't have to be themselves comparable, we
+        # maintain a global count and increment it on every insert. This way,
+        # the tuple `(-priority, count, item)` will never have to compare `item`.
+ self._comparison_counter += 1 + fut = asyncio.get_running_loop().create_future() + waiter_obj = (-priority, self._comparison_counter, fut) + heapq.heappush(self._waiters, waiter_obj) + + try: + try: + await fut + return True + finally: + if waiter_obj in self._waiters: + self._waiters.remove(waiter_obj) + heapq.heapify(self._waiters) + except asyncio.CancelledError: + if fut.done() and not fut.cancelled(): + self._active_requests_by_tier[effective_tier] -= 1 + + raise + finally: + self._wake_waiters() + + def _release(self, priority: int = 0) -> None: + """Releases an acquired slot back to the limiter.""" + effective_tier = self._get_effective_priority_tier(priority) + assert self._active_requests_by_tier[effective_tier] > 0 + self._active_requests_by_tier[effective_tier] -= 1 + self._wake_waiters() + + def cancel_waiting(self, exc: BaseException) -> None: + """Cancel all waiters with the given exception.""" + for _, _, fut in self._waiters: + if not fut.done(): + fut.set_exception(exc) + + def __repr__(self) -> str: + """Provides a string representation of the limiter's state.""" + return ( + f"<{self.__class__.__name__}(" + f"max_concurrency={self._max_concurrency}" + f", active={self.active_requests}" + f", waiting={self.waiting_requests}" + f")>" + ) diff --git a/zigpy/device.py b/zigpy/device.py index b68ec5b40..316cb4f61 100644 --- a/zigpy/device.py +++ b/zigpy/device.py @@ -126,8 +126,13 @@ def __init__(self, application: ControllerApplication, ieee: t.EUI64, nwk: t.NWK self._tasks: set[asyncio.Future[Any]] = set() self._packet_debouncer = zigpy.datastructures.Debouncer() - self._concurrent_requests_semaphore = ( - zigpy.datastructures.PriorityDynamicBoundedSemaphore(MAX_DEVICE_CONCURRENCY) + self._concurrent_requests_semaphore = zigpy.datastructures.RequestLimiter( + max_concurrency=MAX_DEVICE_CONCURRENCY, + capacities={ + t.PacketPriority.HIGH: 0.5, + # t.PacketPriority.NORMAL is shared with LOW + t.PacketPriority.LOW: 0.5, + }, ) # Retained for backwards compatibility, will be removed in a future release @@ -179,19 +184,19 @@ async def _limit_concurrency(self, *, priority: int | None = None): LOGGER.debug( "Critical priority request received (%s), skipping queue with %d requests", priority, - self._concurrent_requests_semaphore.num_waiting, + self._concurrent_requests_semaphore.waiting_requests, ) manager = nullcontext() was_locked = False else: manager = self._concurrent_requests_semaphore(priority=priority) - was_locked = self._concurrent_requests_semaphore.locked() + was_locked = self._concurrent_requests_semaphore.locked(priority=priority) if was_locked: LOGGER.debug( - "Device concurrency (%s) reached, delaying device request (%s enqueued)", - self._concurrent_requests_semaphore.max_value, - self._concurrent_requests_semaphore.num_waiting, + "Device concurrency (%s) reached, delaying request (%s enqueued)", + self._concurrent_requests_semaphore.active_requests, + self._concurrent_requests_semaphore.waiting_requests, ) async with manager: @@ -350,9 +355,11 @@ async def poll_control_checkin_callback( # to be sent if ( self.initializing - or self._concurrent_requests_semaphore.locked() + or self._concurrent_requests_semaphore.active_requests > 0 or self._fast_polling ): + # Initiate fast polling mode if we are initializing or waiting for + # requests to be sent await poll_control.checkin_response( start_fast_polling=True, fast_poll_timeout=int(DEFAULT_FAST_POLL_TIMEOUT * 4), diff --git a/zigpy/util.py b/zigpy/util.py index d5b65009e..30e603776 100644 --- a/zigpy/util.py +++ 
b/zigpy/util.py @@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.ciphers.modes import ECB -from zigpy.datastructures import DynamicBoundedSemaphore # noqa: F401 from zigpy.exceptions import ZigbeeException import zigpy.types as t
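
The sketch below illustrates the cascading capacity computed by the _recalculate_capacity hunk above: each priority fraction becomes a cumulative slot count, and locked() compares the number of active requests at or below a tier against that tier's cumulative limit. This is a standalone illustration, not code from this patch; the helper name cumulative_capacity and the integer priorities are stand-ins for t.PacketPriority members.

import math
from fractions import Fraction


def cumulative_capacity(max_concurrency: int, capacities: dict[int, float]) -> dict[int, int]:
    """Mirror the patch's capacity math for a fixed concurrency (illustrative only)."""
    # Round max_concurrency up to a multiple of the least common denominator
    # so that every fraction maps to a whole number of slots.
    lcm = math.lcm(
        *(Fraction.from_float(f).limit_denominator(10).denominator for f in capacities.values())
    )
    if max_concurrency % lcm != 0:
        max_concurrency += lcm - max_concurrency % lcm

    cumulative = 0
    tiers: dict[int, int] = {}
    for priority in sorted(capacities):
        cumulative += round(capacities[priority] * max_concurrency)
        tiers[priority] = cumulative
    return tiers


# Stand-in priorities; the defaults in zigpy/config/defaults.py above are
# LOW=0.25, NORMAL=0.50, HIGH=0.25.
LOW, NORMAL, HIGH = 0, 1, 2
print(cumulative_capacity(8, {LOW: 0.25, NORMAL: 0.50, HIGH: 0.25}))
# {0: 2, 1: 6, 2: 8} -- LOW alone may hold 2 slots, LOW+NORMAL 6, all tiers 8.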
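
At the call sites in application.py and device.py above, the limiter is entered as an async context manager keyed by packet priority, with CRITICAL traffic bypassing it entirely via nullcontext(). Assuming this patch is applied, a minimal usage sketch looks like the following; the integer priorities and the send coroutine are illustrative only.

import asyncio

from zigpy.datastructures import RequestLimiter


async def main() -> None:
    # Stand-in priorities; the real call sites pass t.PacketPriority members.
    LOW, NORMAL, HIGH = 0, 1, 2
    limiter = RequestLimiter(
        max_concurrency=4, capacities={LOW: 0.25, NORMAL: 0.50, HIGH: 0.25}
    )

    async def send(priority: int, name: str) -> None:
        # At most max_concurrency requests run concurrently, and LOW traffic
        # can never occupy more than its cascading share of the slots.
        async with limiter(priority=priority):
            await asyncio.sleep(0.01)
            print(f"sent {name} (active={limiter.active_requests})")

    await asyncio.gather(*(send(LOW if i % 2 else HIGH, f"req-{i}") for i in range(8)))


asyncio.run(main())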
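
The new experimental.concurrency option accepts lower-cased priority names (every t.PacketPriority member except critical) mapped to fractions strictly between 0 and 1; the vol.All chain in SCHEMA_EXPERIMENTAL then coerces the string keys into t.PacketPriority members. A small validation sketch, assuming this patch is applied:

import zigpy.config as conf
import zigpy.types as t

validated = conf.SCHEMA_EXPERIMENTAL(
    {"concurrency": {"low": 0.25, "normal": 0.50, "high": 0.25}}
)
# The schema upper-cases the string keys and replaces them with enum members.
assert validated["concurrency"][t.PacketPriority.HIGH] == 0.25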